MLIR  22.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Iterators.h"
17 #include "mlir/IR/Operation.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/Support/SaveAndRestore.h"
26 #include "llvm/Support/ScopedPrinter.h"
27 #include <optional>
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 #define DEBUG_TYPE "dialect-conversion"
33 
34 /// A utility function to log a successful result for the given reason.
35 template <typename... Args>
36 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
37  LLVM_DEBUG({
38  os.unindent();
39  os.startLine() << "} -> SUCCESS";
40  if (!fmt.empty())
41  os.getOStream() << " : "
42  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
43  os.getOStream() << "\n";
44  });
45 }
46 
47 /// A utility function to log a failure result for the given reason.
48 template <typename... Args>
49 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
50  LLVM_DEBUG({
51  os.unindent();
52  os.startLine() << "} -> FAILURE : "
53  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
54  << "\n";
55  });
56 }
57 
58 /// Helper function that computes an insertion point where the given value is
59 /// defined and can be used without a dominance violation.
61  Block *insertBlock = value.getParentBlock();
62  Block::iterator insertPt = insertBlock->begin();
63  if (OpResult inputRes = dyn_cast<OpResult>(value))
64  insertPt = ++inputRes.getOwner()->getIterator();
65  return OpBuilder::InsertPoint(insertBlock, insertPt);
66 }
67 
68 /// Helper function that computes an insertion point where the given values are
69 /// defined and can be used without a dominance violation.
71  assert(!vals.empty() && "expected at least one value");
72  DominanceInfo domInfo;
73  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
74  for (Value v : vals.drop_front()) {
75  // Choose the "later" insertion point.
77  if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
78  nextPt.getPoint())) {
79  // pt is before nextPt => choose nextPt.
80  pt = nextPt;
81  } else {
82 #ifndef NDEBUG
83  // nextPt should be before pt => choose pt.
84  // If pt, nextPt are no dominance relationship, then there is no valid
85  // insertion point at which all given values are defined.
86  bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
87  pt.getBlock(), pt.getPoint());
88  assert(dom && "unable to find valid insertion point");
89 #endif // NDEBUG
90  }
91  }
92  return pt;
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // ConversionValueMapping
97 //===----------------------------------------------------------------------===//
98 
99 /// A vector of SSA values, optimized for the most common case of a single
100 /// value.
101 using ValueVector = SmallVector<Value, 1>;
102 
103 namespace {
104 
105 /// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
106 struct ValueVectorMapInfo {
107  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
108  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
109  static ::llvm::hash_code getHashValue(const ValueVector &val) {
110  return ::llvm::hash_combine_range(val);
111  }
112  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
113  return LHS == RHS;
114  }
115 };
116 
117 /// This class wraps a IRMapping to provide recursive lookup
118 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
119 struct ConversionValueMapping {
120  /// Return "true" if an SSA value is mapped to the given value. May return
121  /// false positives.
122  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
123 
124  /// Lookup a value in the mapping.
125  ValueVector lookup(const ValueVector &from) const;
126 
127  template <typename T>
128  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
129 
130  /// Map a value vector to the one provided.
131  template <typename OldVal, typename NewVal>
132  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
133  map(OldVal &&oldVal, NewVal &&newVal) {
134  LLVM_DEBUG({
135  ValueVector next(newVal);
136  while (true) {
137  assert(next != oldVal && "inserting cyclic mapping");
138  auto it = mapping.find(next);
139  if (it == mapping.end())
140  break;
141  next = it->second;
142  }
143  });
144  mappedTo.insert_range(newVal);
145 
146  mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
147  }
148 
149  /// Map a value vector or single value to the one provided.
150  template <typename OldVal, typename NewVal>
151  std::enable_if_t<!IsValueVector<OldVal>::value ||
152  !IsValueVector<NewVal>::value>
153  map(OldVal &&oldVal, NewVal &&newVal) {
154  if constexpr (IsValueVector<OldVal>{}) {
155  map(std::forward<OldVal>(oldVal), ValueVector{newVal});
156  } else if constexpr (IsValueVector<NewVal>{}) {
157  map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
158  } else {
159  map(ValueVector{oldVal}, ValueVector{newVal});
160  }
161  }
162 
163  void map(Value oldVal, SmallVector<Value> &&newVal) {
164  map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
165  }
166 
167  /// Drop the last mapping for the given values.
168  void erase(const ValueVector &value) { mapping.erase(value); }
169 
170 private:
171  /// Current value mappings.
173 
174  /// All SSA values that are mapped to. May contain false positives.
175  DenseSet<Value> mappedTo;
176 };
177 } // namespace
178 
179 /// Marker attribute for pure type conversions. I.e., mappings whose only
180 /// purpose is to resolve a type mismatch. (In contrast, mappings that point to
181 /// the replacement values of a "replaceOp" call, etc., are not pure type
182 /// conversions.)
183 static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
184 
185 /// Return the operation that defines all values in the vector. Return nullptr
186 /// if the values are not defined by the same operation.
187 static Operation *getCommonDefiningOp(const ValueVector &values) {
188  assert(!values.empty() && "expected non-empty value vector");
189  Operation *op = values.front().getDefiningOp();
190  for (Value v : llvm::drop_begin(values)) {
191  if (v.getDefiningOp() != op)
192  return nullptr;
193  }
194  return op;
195 }
196 
197 /// A vector of values is a pure type conversion if all values are defined by
198 /// the same operation and the operation has the `kPureTypeConversionMarker`
199 /// attribute.
200 static bool isPureTypeConversion(const ValueVector &values) {
201  assert(!values.empty() && "expected non-empty value vector");
202  Operation *op = getCommonDefiningOp(values);
203  return op && op->hasAttr(kPureTypeConversionMarker);
204 }
205 
206 ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
207  auto it = mapping.find(from);
208  if (it == mapping.end()) {
209  // No mapping found: The lookup stops here.
210  return {};
211  }
212  return it->second;
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // Rewriter and Translation State
217 //===----------------------------------------------------------------------===//
218 namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. Taking one
/// before a set of rewrites makes it possible to undo back to this point.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites{numRewrites}, numIgnoredOperations{numIgnoredOperations},
        numReplacedOps{numReplacedOps} {}

  /// How many rewrites had been recorded.
  unsigned numRewrites;

  /// How many operations were being ignored.
  unsigned numIgnoredOperations;

  /// How many replaced ops were scheduled for erasure.
  unsigned numReplacedOps;
};
236 
237 //===----------------------------------------------------------------------===//
238 // IR rewrites
239 //===----------------------------------------------------------------------===//
240 
241 static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
242 
243 /// Notify the listener that the given block and its contents are being erased.
244 static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
245  for (Operation &op : b)
246  notifyIRErased(listener, op);
247  listener->notifyBlockErased(&b);
248 }
249 
250 /// Notify the listener that the given operation and its contents are being
251 /// erased.
252 static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
253  for (Region &r : op.getRegions()) {
254  for (Block &b : r) {
255  notifyIRErased(listener, b);
256  }
257  }
258  listener->notifyOperationErased(&op);
259 }
260 
261 /// An IR rewrite that can be committed (upon success) or rolled back (upon
262 /// failure).
263 ///
264 /// The dialect conversion keeps track of IR modifications (requested by the
265 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
266 /// are directly applied to the IR as the rewriter API is used, some are applied
267 /// partially, and some are delayed until the `IRRewrite` objects are committed.
268 class IRRewrite {
269 public:
270  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
271  /// Enum values are ordered, so that they can be used in `classof`: first all
272  /// block rewrites, then all operation rewrites.
273  enum class Kind {
274  // Block rewrites
275  CreateBlock,
276  EraseBlock,
277  InlineBlock,
278  MoveBlock,
279  BlockTypeConversion,
280  // Operation rewrites
281  MoveOperation,
282  ModifyOperation,
283  ReplaceOperation,
284  CreateOperation,
285  UnresolvedMaterialization,
286  // Value rewrites
287  ReplaceValue
288  };
289 
290  virtual ~IRRewrite() = default;
291 
292  /// Roll back the rewrite. Operations may be erased during rollback.
293  virtual void rollback() = 0;
294 
295  /// Commit the rewrite. At this point, it is certain that the dialect
296  /// conversion will succeed. All IR modifications, except for operation/block
297  /// erasure, must be performed through the given rewriter.
298  ///
299  /// Instead of erasing operations/blocks, they should merely be unlinked
300  /// commit phase and finally be erased during the cleanup phase. This is
301  /// because internal dialect conversion state (such as `mapping`) may still
302  /// be using them.
303  ///
304  /// Any IR modification that was already performed before the commit phase
305  /// (e.g., insertion of an op) must be communicated to the listener that may
306  /// be attached to the given rewriter.
307  virtual void commit(RewriterBase &rewriter) {}
308 
309  /// Cleanup operations/blocks. Cleanup is called after commit.
310  virtual void cleanup(RewriterBase &rewriter) {}
311 
312  Kind getKind() const { return kind; }
313 
314  static bool classof(const IRRewrite *rewrite) { return true; }
315 
316 protected:
317  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
318  : kind(kind), rewriterImpl(rewriterImpl) {}
319 
320  const ConversionConfig &getConfig() const;
321 
322  const Kind kind;
323  ConversionPatternRewriterImpl &rewriterImpl;
324 };
325 
326 /// A block rewrite.
327 class BlockRewrite : public IRRewrite {
328 public:
329  /// Return the block that this rewrite operates on.
330  Block *getBlock() const { return block; }
331 
332  static bool classof(const IRRewrite *rewrite) {
333  return rewrite->getKind() >= Kind::CreateBlock &&
334  rewrite->getKind() <= Kind::BlockTypeConversion;
335  }
336 
337 protected:
338  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
339  Block *block)
340  : IRRewrite(kind, rewriterImpl), block(block) {}
341 
342  // The block that this rewrite operates on.
343  Block *block;
344 };
345 
346 /// A value rewrite.
347 class ValueRewrite : public IRRewrite {
348 public:
349  /// Return the value that this rewrite operates on.
350  Value getValue() const { return value; }
351 
352  static bool classof(const IRRewrite *rewrite) {
353  return rewrite->getKind() == Kind::ReplaceValue;
354  }
355 
356 protected:
357  ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
358  Value value)
359  : IRRewrite(kind, rewriterImpl), value(value) {}
360 
361  // The value that this rewrite operates on.
362  Value value;
363 };
364 
365 /// Creation of a block. Block creations are immediately reflected in the IR.
366 /// There is no extra work to commit the rewrite. During rollback, the newly
367 /// created block is erased.
368 class CreateBlockRewrite : public BlockRewrite {
369 public:
370  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
371  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
372 
373  static bool classof(const IRRewrite *rewrite) {
374  return rewrite->getKind() == Kind::CreateBlock;
375  }
376 
377  void commit(RewriterBase &rewriter) override {
378  // The block was already created and inserted. Just inform the listener.
379  if (auto *listener = rewriter.getListener())
380  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
381  }
382 
383  void rollback() override {
384  // Unlink all of the operations within this block, they will be deleted
385  // separately.
386  auto &blockOps = block->getOperations();
387  while (!blockOps.empty())
388  blockOps.remove(blockOps.begin());
389  block->dropAllUses();
390  if (block->getParent())
391  block->erase();
392  else
393  delete block;
394  }
395 };
396 
397 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
398 /// blocks are immediately unlinked, but only erased during cleanup. This makes
399 /// it easier to rollback a block erasure: the block is simply inserted into its
400 /// original location.
401 class EraseBlockRewrite : public BlockRewrite {
402 public:
403  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
404  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
405  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
406 
407  static bool classof(const IRRewrite *rewrite) {
408  return rewrite->getKind() == Kind::EraseBlock;
409  }
410 
411  ~EraseBlockRewrite() override {
412  assert(!block &&
413  "rewrite was neither rolled back nor committed/cleaned up");
414  }
415 
416  void rollback() override {
417  // The block (owned by this rewrite) was not actually erased yet. It was
418  // just unlinked. Put it back into its original position.
419  assert(block && "expected block");
420  auto &blockList = region->getBlocks();
421  Region::iterator before = insertBeforeBlock
422  ? Region::iterator(insertBeforeBlock)
423  : blockList.end();
424  blockList.insert(before, block);
425  block = nullptr;
426  }
427 
428  void commit(RewriterBase &rewriter) override {
429  assert(block && "expected block");
430 
431  // Notify the listener that the block and its contents are being erased.
432  if (auto *listener =
433  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
434  notifyIRErased(listener, *block);
435  }
436 
437  void cleanup(RewriterBase &rewriter) override {
438  // Erase the contents of the block.
439  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
440  rewriter.eraseOp(&op);
441  assert(block->empty() && "expected empty block");
442 
443  // Erase the block.
444  block->dropAllDefinedValueUses();
445  delete block;
446  block = nullptr;
447  }
448 
449 private:
450  // The region in which this block was previously contained.
451  Region *region;
452 
453  // The original successor of this block before it was unlinked. "nullptr" if
454  // this block was the only block in the region.
455  Block *insertBeforeBlock;
456 };
457 
458 /// Inlining of a block. This rewrite is immediately reflected in the IR.
459 /// Note: This rewrite represents only the inlining of the operations. The
460 /// erasure of the inlined block is a separate rewrite.
461 class InlineBlockRewrite : public BlockRewrite {
462 public:
463  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
464  Block *sourceBlock, Block::iterator before)
465  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
466  sourceBlock(sourceBlock),
467  firstInlinedInst(sourceBlock->empty() ? nullptr
468  : &sourceBlock->front()),
469  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
470  // If a listener is attached to the dialect conversion, ops must be moved
471  // one-by-one. When they are moved in bulk, notifications cannot be sent
472  // because the ops that used to be in the source block at the time of the
473  // inlining (before the "commit" phase) are unknown at the time when
474  // notifications are sent (which is during the "commit" phase).
475  assert(!getConfig().listener &&
476  "InlineBlockRewrite not supported if listener is attached");
477  }
478 
479  static bool classof(const IRRewrite *rewrite) {
480  return rewrite->getKind() == Kind::InlineBlock;
481  }
482 
483  void rollback() override {
484  // Put the operations from the destination block (owned by the rewrite)
485  // back into the source block.
486  if (firstInlinedInst) {
487  assert(lastInlinedInst && "expected operation");
488  sourceBlock->getOperations().splice(sourceBlock->begin(),
489  block->getOperations(),
490  Block::iterator(firstInlinedInst),
491  ++Block::iterator(lastInlinedInst));
492  }
493  }
494 
495 private:
496  // The block that originally contained the operations.
497  Block *sourceBlock;
498 
499  // The first inlined operation.
500  Operation *firstInlinedInst;
501 
502  // The last inlined operation.
503  Operation *lastInlinedInst;
504 };
505 
506 /// Moving of a block. This rewrite is immediately reflected in the IR.
507 class MoveBlockRewrite : public BlockRewrite {
508 public:
509  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
510  Region *previousRegion, Region::iterator previousIt)
511  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
512  region(previousRegion),
513  insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
514  : &*previousIt) {}
515 
516  static bool classof(const IRRewrite *rewrite) {
517  return rewrite->getKind() == Kind::MoveBlock;
518  }
519 
520  void commit(RewriterBase &rewriter) override {
521  // The block was already moved. Just inform the listener.
522  if (auto *listener = rewriter.getListener()) {
523  // Note: `previousIt` cannot be passed because this is a delayed
524  // notification and iterators into past IR state cannot be represented.
525  listener->notifyBlockInserted(block, /*previous=*/region,
526  /*previousIt=*/{});
527  }
528  }
529 
530  void rollback() override {
531  // Move the block back to its original position.
532  Region::iterator before =
533  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
534  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
535  }
536 
537 private:
538  // The region in which this block was previously contained.
539  Region *region;
540 
541  // The original successor of this block before it was moved. "nullptr" if
542  // this block was the only block in the region.
543  Block *insertBeforeBlock;
544 };
545 
546 /// Block type conversion. This rewrite is partially reflected in the IR.
547 class BlockTypeConversionRewrite : public BlockRewrite {
548 public:
549  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
550  Block *origBlock, Block *newBlock)
551  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
552  newBlock(newBlock) {}
553 
554  static bool classof(const IRRewrite *rewrite) {
555  return rewrite->getKind() == Kind::BlockTypeConversion;
556  }
557 
558  Block *getOrigBlock() const { return block; }
559 
560  Block *getNewBlock() const { return newBlock; }
561 
562  void commit(RewriterBase &rewriter) override;
563 
564  void rollback() override;
565 
566 private:
567  /// The new block that was created as part of this signature conversion.
568  Block *newBlock;
569 };
570 
571 /// Replacing a value. This rewrite is not immediately reflected in the
572 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
573 /// until the rewrite is committed.
574 class ReplaceValueRewrite : public ValueRewrite {
575 public:
576  ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
577  const TypeConverter *converter)
578  : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
579  converter(converter) {}
580 
581  static bool classof(const IRRewrite *rewrite) {
582  return rewrite->getKind() == Kind::ReplaceValue;
583  }
584 
585  void commit(RewriterBase &rewriter) override;
586 
587  void rollback() override;
588 
589 private:
590  /// The current type converter when the value was replaced.
591  const TypeConverter *converter;
592 };
593 
594 /// An operation rewrite.
595 class OperationRewrite : public IRRewrite {
596 public:
597  /// Return the operation that this rewrite operates on.
598  Operation *getOperation() const { return op; }
599 
600  static bool classof(const IRRewrite *rewrite) {
601  return rewrite->getKind() >= Kind::MoveOperation &&
602  rewrite->getKind() <= Kind::UnresolvedMaterialization;
603  }
604 
605 protected:
606  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
607  Operation *op)
608  : IRRewrite(kind, rewriterImpl), op(op) {}
609 
610  // The operation that this rewrite operates on.
611  Operation *op;
612 };
613 
614 /// Moving of an operation. This rewrite is immediately reflected in the IR.
615 class MoveOperationRewrite : public OperationRewrite {
616 public:
617  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
618  Operation *op, OpBuilder::InsertPoint previous)
619  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
620  block(previous.getBlock()),
621  insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
622  ? nullptr
623  : &*previous.getPoint()) {}
624 
625  static bool classof(const IRRewrite *rewrite) {
626  return rewrite->getKind() == Kind::MoveOperation;
627  }
628 
629  void commit(RewriterBase &rewriter) override {
630  // The operation was already moved. Just inform the listener.
631  if (auto *listener = rewriter.getListener()) {
632  // Note: `previousIt` cannot be passed because this is a delayed
633  // notification and iterators into past IR state cannot be represented.
634  listener->notifyOperationInserted(
635  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
636  /*insertPt=*/{}));
637  }
638  }
639 
640  void rollback() override {
641  // Move the operation back to its original position.
642  Block::iterator before =
643  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
644  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
645  }
646 
647 private:
648  // The block in which this operation was previously contained.
649  Block *block;
650 
651  // The original successor of this operation before it was moved. "nullptr"
652  // if this operation was the only operation in the region.
653  Operation *insertBeforeOp;
654 };
655 
656 /// In-place modification of an op. This rewrite is immediately reflected in
657 /// the IR. The previous state of the operation is stored in this object.
658 class ModifyOperationRewrite : public OperationRewrite {
659 public:
660  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
661  Operation *op)
662  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
663  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
664  operands(op->operand_begin(), op->operand_end()),
665  successors(op->successor_begin(), op->successor_end()) {
666  if (OpaqueProperties prop = op->getPropertiesStorage()) {
667  // Make a copy of the properties.
668  propertiesStorage = operator new(op->getPropertiesStorageSize());
669  OpaqueProperties propCopy(propertiesStorage);
670  name.initOpProperties(propCopy, /*init=*/prop);
671  }
672  }
673 
674  static bool classof(const IRRewrite *rewrite) {
675  return rewrite->getKind() == Kind::ModifyOperation;
676  }
677 
678  ~ModifyOperationRewrite() override {
679  assert(!propertiesStorage &&
680  "rewrite was neither committed nor rolled back");
681  }
682 
683  void commit(RewriterBase &rewriter) override {
684  // Notify the listener that the operation was modified in-place.
685  if (auto *listener =
686  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
687  listener->notifyOperationModified(op);
688 
689  if (propertiesStorage) {
690  OpaqueProperties propCopy(propertiesStorage);
691  // Note: The operation may have been erased in the mean time, so
692  // OperationName must be stored in this object.
693  name.destroyOpProperties(propCopy);
694  operator delete(propertiesStorage);
695  propertiesStorage = nullptr;
696  }
697  }
698 
699  void rollback() override {
700  op->setLoc(loc);
701  op->setAttrs(attrs);
702  op->setOperands(operands);
703  for (const auto &it : llvm::enumerate(successors))
704  op->setSuccessor(it.value(), it.index());
705  if (propertiesStorage) {
706  OpaqueProperties propCopy(propertiesStorage);
707  op->copyProperties(propCopy);
708  name.destroyOpProperties(propCopy);
709  operator delete(propertiesStorage);
710  propertiesStorage = nullptr;
711  }
712  }
713 
714 private:
715  OperationName name;
716  LocationAttr loc;
717  DictionaryAttr attrs;
718  SmallVector<Value, 8> operands;
719  SmallVector<Block *, 2> successors;
720  void *propertiesStorage = nullptr;
721 };
722 
723 /// Replacing an operation. Erasing an operation is treated as a special case
724 /// with "null" replacements. This rewrite is not immediately reflected in the
725 /// IR. An internal IR mapping is updated, but values are not replaced and the
726 /// original op is not erased until the rewrite is committed.
727 class ReplaceOperationRewrite : public OperationRewrite {
728 public:
729  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
730  Operation *op, const TypeConverter *converter)
731  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
732  converter(converter) {}
733 
734  static bool classof(const IRRewrite *rewrite) {
735  return rewrite->getKind() == Kind::ReplaceOperation;
736  }
737 
738  void commit(RewriterBase &rewriter) override;
739 
740  void rollback() override;
741 
742  void cleanup(RewriterBase &rewriter) override;
743 
744 private:
745  /// An optional type converter that can be used to materialize conversions
746  /// between the new and old values if necessary.
747  const TypeConverter *converter;
748 };
749 
750 class CreateOperationRewrite : public OperationRewrite {
751 public:
752  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
753  Operation *op)
754  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
755 
756  static bool classof(const IRRewrite *rewrite) {
757  return rewrite->getKind() == Kind::CreateOperation;
758  }
759 
760  void commit(RewriterBase &rewriter) override {
761  // The operation was already created and inserted. Just inform the listener.
762  if (auto *listener = rewriter.getListener())
763  listener->notifyOperationInserted(op, /*previous=*/{});
764  }
765 
766  void rollback() override;
767 };
768 
/// Direction of a type materialization.
enum MaterializationKind {
  /// Materializes a conversion from an illegal type to a legal one.
  Target = 0,

  /// Materializes a conversion from a legal type back to an illegal one.
  Source = 1
};
779 
780 /// Helper class that stores metadata about an unresolved materialization.
781 class UnresolvedMaterializationInfo {
782 public:
783  UnresolvedMaterializationInfo() = default;
784  UnresolvedMaterializationInfo(const TypeConverter *converter,
785  MaterializationKind kind, Type originalType)
786  : converterAndKind(converter, kind), originalType(originalType) {}
787 
788  /// Return the type converter of this materialization (which may be null).
789  const TypeConverter *getConverter() const {
790  return converterAndKind.getPointer();
791  }
792 
793  /// Return the kind of this materialization.
794  MaterializationKind getMaterializationKind() const {
795  return converterAndKind.getInt();
796  }
797 
798  /// Return the original type of the SSA value.
799  Type getOriginalType() const { return originalType; }
800 
801 private:
802  /// The corresponding type converter to use when resolving this
803  /// materialization, and the kind of this materialization.
804  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
805  converterAndKind;
806 
807  /// The original type of the SSA value. Only used for target
808  /// materializations.
809  Type originalType;
810 };
811 
812 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
813 /// op. Unresolved materializations fold away or are replaced with
814 /// source/target materializations at the end of the dialect conversion.
815 class UnresolvedMaterializationRewrite : public OperationRewrite {
816 public:
817  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
818  UnrealizedConversionCastOp op,
819  ValueVector mappedValues)
820  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
821  mappedValues(std::move(mappedValues)) {}
822 
823  static bool classof(const IRRewrite *rewrite) {
824  return rewrite->getKind() == Kind::UnresolvedMaterialization;
825  }
826 
827  void rollback() override;
828 
829  UnrealizedConversionCastOp getOperation() const {
830  return cast<UnrealizedConversionCastOp>(op);
831  }
832 
833 private:
834  /// The values in the conversion value mapping that are being replaced by the
835  /// results of this unresolved materialization.
836  ValueVector mappedValues;
837 };
838 } // namespace
839 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// operation is `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  auto matches = [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed && typed->getOperation() == op;
  };
  return llvm::any_of(std::forward<R>(rewrites), matches);
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// block is `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  auto matches = [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed && typed->getBlock() == block;
  };
  return llvm::any_of(std::forward<R>(rewrites), matches);
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
861 
862 //===----------------------------------------------------------------------===//
863 // ConversionPatternRewriterImpl
864 //===----------------------------------------------------------------------===//
865 namespace mlir {
866 namespace detail {
// NOTE(review): the embedded numbering jumps from 866 to 869, so the opening
// of the enclosing declaration and the first line(s) of this constructor are
// missing from this listing (presumably the `ConversionPatternRewriterImpl`
// struct header and a constructor parameter named `rewriter`). As listed, the
// constructor initializes the `rewriter` and `config` members and builds
// `notifyingRewriter` from the rewriter's context and `config.listener`.
869  const ConversionConfig &config)
870  : rewriter(rewriter), config(config),
871  notifyingRewriter(rewriter.getContext(), config.listener) {}
872 
873  //===--------------------------------------------------------------------===//
874  // State Management
875  //===--------------------------------------------------------------------===//
876 
877  /// Return the current state of the rewriter: a snapshot of how many
/// rewrites, ignored ops and replaced ops have been recorded so far. The
/// snapshot can later be handed to `resetState`.
878  RewriterState getCurrentState();
879 
880  /// Apply all requested operation rewrites. This method is invoked when the
881  /// conversion process succeeds.
882  void applyRewrites();
883 
884  /// Reset the state of the rewriter to a previously saved point. Optionally,
885  /// the name of the pattern that triggered the rollback can be specified for
886  /// debugging purposes.
887  void resetState(RewriterState state, StringRef patternName = "");
888 
889  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
890  /// failure.
891  template <typename RewriteTy, typename... Args>
892  void appendRewrite(Args &&...args) {
893  assert(config.allowPatternRollback && "appending rewrites is not allowed");
894  rewrites.push_back(
895  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
896  }
897 
898  /// Undo the rewrites (motions, splits) one by one in reverse order until
899  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
900  /// that triggered the rollback can be specified for debugging purposes.
901  void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
902 
903  /// Remap the given values to those with potentially different types. Returns
904  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
905  /// is the tag used when describing a value within a diagnostic, e.g.
906  /// "operand".
/// If `inputLoc` is set, it is used instead of each value's own location.
/// One entry is appended to `remapped` per processed value; on failure,
/// entries for earlier values may already have been appended.
907  LogicalResult remapValues(StringRef valueDiagTag,
908  std::optional<Location> inputLoc, ValueRange values,
909  SmallVector<ValueVector> &remapped);
910 
911  /// Return "true" if the given operation is ignored, and does not need to be
912  /// converted.
913  bool isOpIgnored(Operation *op) const;
914 
915  /// Return "true" if the given operation was replaced or erased.
916  bool wasOpReplaced(Operation *op) const;
917 
918  /// Lookup the most recently mapped values with the desired types in the
919  /// mapping, taking into account only replacements. Perform a best-effort
920  /// search for existing materializations with the desired types.
921  ///
922  /// If `skipPureTypeConversions` is "true", materializations that are pure
923  /// type conversions are not considered.
924  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
925  bool skipPureTypeConversions = false) const;
926 
927  /// Lookup the given value within the map, or return an empty vector if the
928  /// value is not mapped. If it is mapped, this follows the same behavior
929  /// as `lookupOrDefault`.
/// An empty vector is also returned if `desiredTypes` is non-empty and no
/// mapped values of exactly those types were found.
930  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
931 
932  //===--------------------------------------------------------------------===//
933  // IR Rewrites / Type Conversion
934  //===--------------------------------------------------------------------===//
935 
936  /// Convert the types of block arguments within the given region.
/// `converter` is also registered as the region's type converter. Returns
/// the (possibly new) entry block, nullptr if the region is empty, or failure
/// if the signature of some block could not be converted.
937  FailureOr<Block *>
938  convertRegionTypes(Region *region, const TypeConverter &converter,
939  TypeConverter::SignatureConversion *entryConversion);
940 
941  /// Apply the given signature conversion on the given block. The new block
942  /// containing the updated signature is returned. If no conversions were
943  /// necessary, e.g. if the block has no arguments, `block` is returned.
944  /// `converter` is used to generate any necessary cast operations that
945  /// translate between the original argument types and those specified in the
946  /// signature conversion.
947  Block *applySignatureConversion(
948  Block *block, const TypeConverter *converter,
949  TypeConverter::SignatureConversion &signatureConversion);
950 
951  /// Replace the results of the given operation with the given values and
952  /// erase the operation.
953  ///
954  /// There can be multiple replacement values for each result (1:N
955  /// replacement). If the replacement values are empty, the respective result
956  /// is dropped and a source materialization is built if the result still has
957  /// uses.
958  void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
959 
960  /// Replace the uses of the given value with the given values. The specified
961  /// converter is used to build materializations (if necessary).
962  void replaceAllUsesWith(Value from, ValueRange to,
963  const TypeConverter *converter);
964 
965  /// Erase the given block and its contents.
966  void eraseBlock(Block *block);
967 
968  /// Inline the source block into the destination block before the given
969  /// iterator.
970  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
971 
972  //===--------------------------------------------------------------------===//
973  // Materializations
974  //===--------------------------------------------------------------------===//
975 
976  /// Build an unresolved materialization operation given a range of output
977  /// types and a list of input operands. The input types must not already
978  /// match the output types (this is asserted).
979  ///
980  /// The results of the newly created `unrealized_conversion_cast` op are
981  /// returned.
982  ///
983  /// If `valuesToMap` is non-empty, those values are mapped to
984  /// the results of the unresolved materialization in the conversion value
985  /// mapping.
986  ///
987  /// If `isPureTypeConversion` is "true", the materialization is created only
988  /// to resolve a type mismatch. That means it is not a regular value
989  /// replacement issued by the user. (Replacement values that are created
990  /// "out of thin air" appear like unresolved materializations because they are
991  /// unrealized_conversion_cast ops. However, they must be treated like
992  /// regular value replacements.)
993  ValueRange buildUnresolvedMaterialization(
994  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
995  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
996  Type originalType, const TypeConverter *converter,
997  bool isPureTypeConversion = true);
998 
999  /// Find a replacement value for the given SSA value in the conversion value
1000  /// mapping. The replacement value must have the same type as the given SSA
1001  /// value. If there is no replacement value with the correct type, find the
1002  /// latest replacement value (regardless of the type) and build a source
1003  /// materialization.
1004  Value findOrBuildReplacementValue(Value value,
1005  const TypeConverter *converter);
1006 
1007  //===--------------------------------------------------------------------===//
1008  // Rewriter Notification Hooks
1009  //===--------------------------------------------------------------------===//
1010 
1011  /// Notifies that an op was inserted.
1012  void notifyOperationInserted(Operation *op,
1013  OpBuilder::InsertPoint previous) override;
1014 
1015  /// Notifies that a block was inserted.
1016  void notifyBlockInserted(Block *block, Region *previous,
1017  Region::iterator previousIt) override;
1018 
1019  /// Notifies that a pattern match failed for the given reason.
1020  void
1021  notifyMatchFailure(Location loc,
1022  function_ref<void(Diagnostic &)> reasonCallback) override;
1023 
1024  //===--------------------------------------------------------------------===//
1025  // IR Erasure
1026  //===--------------------------------------------------------------------===//
1027 
1028  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1029  /// operation or block is erased multiple times. This rewriter assumes that
1030  /// no new IR is created between calls to `eraseOp`/`eraseBlock`
/// (presumably because a newly allocated op or block could reuse the address
/// of an erased one and would then be mistaken for already erased).
// NOTE(review): the numbering skips 1031 and 1033, so the class header
// (presumably `SingleEraseRewriter`, the name under which it is instantiated
// in `applyRewrites`) and the first line of its constructor are missing from
// this listing.
1032  public:
// The rewriter registers itself as its own listener (`/*listener=*/this`), so
// its `notifyOperationErased` / `notifyBlockErased` overrides see the
// erasures performed through it. `opErasedCallback` is optional.
1034  MLIRContext *context,
1035  std::function<void(Operation *)> opErasedCallback = nullptr)
1036  : RewriterBase(context, /*listener=*/this),
1037  opErasedCallback(opErasedCallback) {}
1038 
1039  /// Erase the given op (unless it was already erased).
1040  void eraseOp(Operation *op) override {
1041  if (wasErased(op))
1042  return;
1043  op->dropAllUses();
1045  }
1046 
1047  /// Erase the given block (unless it was already erased).
1048  void eraseBlock(Block *block) override {
1049  if (wasErased(block))
1050  return;
1051  assert(block->empty() && "expected empty block");
1052  block->dropAllDefinedValueUses();
1053  RewriterBase::eraseBlock(block);
1054  }
1055 
1056  bool wasErased(void *ptr) const { return erased.contains(ptr); }
1057 
1058  void notifyOperationErased(Operation *op) override {
1059  erased.insert(op);
1060  if (opErasedCallback)
1061  opErasedCallback(op);
1062  }
1063 
1064  void notifyBlockErased(Block *block) override { erased.insert(block); }
1065 
1066  private:
1067  /// Pointers to all erased operations and blocks.
1068  DenseSet<void *> erased;
1069 
1070  /// A callback that is invoked when an operation is erased.
1071  std::function<void(Operation *)> opErasedCallback;
1072  };
1073 
1074  //===--------------------------------------------------------------------===//
1075  // State
1076  //===--------------------------------------------------------------------===//
// NOTE(review): many member declarations in this section are missing from
// this listing (gaps in the embedded numbering). Where code elsewhere in the
// file refers to a member, its name is given below; the types are not
// visible here.
1077 
1078  /// The rewriter that is used to perform the conversion.
// NOTE(review): declaration missing (numbering skips 1079); presumably the
// `rewriter` member used e.g. by `rewriter.createBlock(...)`.
1080 
1081  /// Mapping between replaced values that differ in type. This happens when
1082  /// replacing a value with one of a different type.
1083  ConversionValueMapping mapping;
1084 
1085  /// Ordered list of block operations (creations, splits, motions).
1086  /// This vector is maintained only if `allowPatternRollback` is set to
1087  /// "true". Otherwise, all IR rewrites are materialized immediately and no
1088  /// bookkeeping is needed.
// NOTE(review): declaration missing (skips 1089); presumably `rewrites`.
1090 
1091  /// A set of operations that should no longer be considered for legalization.
1092  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1093  /// tracked separately.
// NOTE(review): declaration missing (skips 1094); presumably `ignoredOps`.
1095 
1096  /// A set of operations that were replaced/erased. Such ops are not erased
1097  /// immediately but only when the dialect conversion succeeds. In the
1098  /// meantime, they should no longer be considered for legalization and any attempt
1099  /// to modify/access them is invalid rewriter API usage.
// NOTE(review): declaration missing (skips 1100); presumably `replacedOps`.
1101 
1102  /// A set of operations that were created by the current pattern.
// NOTE(review): declaration missing (skips 1103); name not visible here.
1104 
1105  /// A set of operations that were modified by the current pattern.
// NOTE(review): declaration missing (skips 1106); name not visible here.
1107 
1108  /// A set of blocks that were inserted (newly-created blocks or moved blocks)
1109  /// by the current pattern.
// NOTE(review): declaration missing (skips 1110); name not visible here.
1111 
1112  /// A list of unresolved materializations that were created by the current
1113  /// pattern.
// NOTE(review): declaration missing (skips 1114); presumably
// `patternMaterializations`.
1115 
1116  /// A mapping for looking up metadata of unresolved materializations.
// NOTE(review): declaration missing (skips 1117-1118); presumably
// `unresolvedMaterializations`.
1119 
1120  /// The current type converter, or nullptr if no type converter is currently
1121  /// active.
1122  const TypeConverter *currentTypeConverter = nullptr;
1123 
1124  /// A mapping of regions to type converters that should be used when
1125  /// converting the arguments of blocks within that region.
// NOTE(review): declaration missing (skips 1126); presumably
// `regionToConverter`.
1127 
1128  /// Dialect conversion configuration.
// NOTE(review): declaration missing (skips 1129); presumably `config`.
1130 
1131  /// A set of erased operations. This set is utilized only if
1132  /// `allowPatternRollback` is set to "false". Conceptually, this set is
1133  /// similar to `replacedOps` (which is maintained when the flag is set to
1134  /// "true"). However, erasing from a DenseSet is more efficient than erasing
1135  /// from a SetVector.
// NOTE(review): declaration missing (skips 1136); presumably `erasedOps`.
1137 
1138  /// A set of erased blocks. This set is utilized only if
1139  /// `allowPatternRollback` is set to "false".
// NOTE(review): declaration missing (skips 1140); name not visible here.
1141 
1142  /// A rewriter that notifies the listener (if any) about all IR
1143  /// modifications. This rewriter is utilized only if `allowPatternRollback`
1144  /// is set to "false". If the flag is set to "true", the listener is notified
1145  /// with a separate mechanism (e.g., in `IRRewrite::commit`).
// NOTE(review): declaration missing (skips 1146); presumably
// `notifyingRewriter`.
1147 
1148 #ifndef NDEBUG
1149  /// A set of replaced values. This set is for debugging purposes only and it
1150  /// is maintained only if `allowPatternRollback` is set to "true".
// NOTE(review): declaration missing (skips 1151); presumably `replacedValues`.
1152 
1153  /// A set of operations that have pending updates. This tracking isn't
1154  /// strictly necessary, and is thus only active during debug builds for extra
1155  /// verification.
// NOTE(review): declaration missing (skips 1156); name not visible here.
1157 
1158  /// A raw output stream used to prefix the debug log.
1159  llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
1160  llvm::dbgs()};
1161 
1162  /// A logger used to emit diagnostics during the conversion process.
1163  llvm::ScopedPrinter logger{os};
1164  std::string logPrefix;
1165 #endif
1166 };
1167 } // namespace detail
1168 } // namespace mlir
1169 
1170 const ConversionConfig &IRRewrite::getConfig() const {
1171  return rewriterImpl.config;
1172 }
1173 
1174 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1175  // Inform the listener about all IR modifications that have already taken
1176  // place: References to the original block have been replaced with the new
1177  // block.
1178  if (auto *listener =
1179  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1180  for (Operation *op : getNewBlock()->getUsers())
1181  listener->notifyOperationModified(op);
1182 }
1183 
1184 void BlockTypeConversionRewrite::rollback() {
1185  getNewBlock()->replaceAllUsesWith(getOrigBlock());
1186 }
1187 
1188 /// Replace all uses of `from` with `repl`.
1189 static void performReplaceValue(RewriterBase &rewriter, Value from,
1190  Value repl) {
1191  if (isa<BlockArgument>(repl)) {
1192  // `repl` is a block argument. Directly replace all uses.
1193  rewriter.replaceAllUsesWith(from, repl);
1194  return;
1195  }
1196 
1197  // If the replacement value is an operation, only replace those uses that:
1198  // - are in a different block than the replacement operation, or
1199  // - are in the same block but after the replacement operation.
1200  //
1201  // Example:
1202  // ^bb0(%arg0: i32):
1203  // %0 = "consumer"(%arg0) : (i32) -> (i32)
1204  // "another_consumer"(%arg0) : (i32) -> ()
1205  //
1206  // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1207  // use in "another_consumer" but not the use in "consumer". When using the
1208  // normal RewriterBase API, this would typically be done with
1209  // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1210  // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1211  // it cannot be supported efficiently with `allowPatternRollback` set to
1212  // "true". Therefore, the conversion driver is trying to be smart and replaces
1213  // only those uses that do not lead to a dominance violation. E.g., the
1214  // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1215  // behavior.
1216  //
1217  // TODO: As we move more and more towards `allowPatternRollback` set to
1218  // "false", we should remove this special handling, in order to align the
1219  // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1220  Operation *replOp = repl.getDefiningOp();
1221  Block *replBlock = replOp->getBlock();
1222  rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
1223  Operation *user = operand.getOwner();
1224  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1225  });
1226 }
1227 
1228 void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1229  Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
1230  if (!repl)
1231  return;
1232  performReplaceValue(rewriter, value, repl);
1233 }
1234 
1235 void ReplaceValueRewrite::rollback() {
1236  rewriterImpl.mapping.erase({value});
1237 #ifndef NDEBUG
1238  rewriterImpl.replacedValues.erase(value);
1239 #endif // NDEBUG
1240 }
1241 
1242 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1243  auto *listener =
1244  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1245 
1246  // Compute replacement values.
1247  SmallVector<Value> replacements =
1248  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1249  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1250  });
1251 
1252  // Notify the listener that the operation is about to be replaced.
1253  if (listener)
1254  listener->notifyOperationReplaced(op, replacements);
1255 
1256  // Replace all uses with the new values.
1257  for (auto [result, newValue] :
1258  llvm::zip_equal(op->getResults(), replacements))
1259  if (newValue)
1260  rewriter.replaceAllUsesWith(result, newValue);
1261 
1262  // The original op will be erased, so remove it from the set of unlegalized
1263  // ops.
1264  if (getConfig().unlegalizedOps)
1265  getConfig().unlegalizedOps->erase(op);
1266 
1267  // Notify the listener that the operation and its contents are being erased.
1268  if (listener)
1269  notifyIRErased(listener, *op);
1270 
1271  // Do not erase the operation yet. It may still be referenced in `mapping`.
1272  // Just unlink it for now and erase it during cleanup.
1273  op->getBlock()->getOperations().remove(op);
1274 }
1275 
1276 void ReplaceOperationRewrite::rollback() {
1277  for (auto result : op->getResults())
1278  rewriterImpl.mapping.erase({result});
1279 }
1280 
1281 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1282  rewriter.eraseOp(op);
1283 }
1284 
1285 void CreateOperationRewrite::rollback() {
1286  for (Region &region : op->getRegions()) {
1287  while (!region.getBlocks().empty())
1288  region.getBlocks().remove(region.getBlocks().begin());
1289  }
1290  op->dropAllUses();
1291  op->erase();
1292 }
1293 
1294 void UnresolvedMaterializationRewrite::rollback() {
1295  if (!mappedValues.empty())
1296  rewriterImpl.mapping.erase(mappedValues);
1297  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1298  op->erase();
1299 }
1300 
// NOTE(review): the function header is missing from this listing (numbering
// skips 1301); presumably `void ConversionPatternRewriterImpl::applyRewrites() {`,
// matching the in-class declaration. The body commits every recorded rewrite,
// then runs the cleanup of every rewrite.
1302  // Commit all rewrites. Use a new rewriter, so the modifications are not
1303  // tracked for rollback purposes etc.
1304  IRRewriter irRewriter(rewriter.getContext(), config.listener);
1305  // Note: New rewrites may be added during the "commit" phase and the
1306  // `rewrites` vector may reallocate.
// That is why this loop is index-based and re-reads `rewrites.size()` on
// every iteration: a range-for would be invalidated by a reallocation.
1307  for (size_t i = 0; i < rewrites.size(); ++i)
1308  rewrites[i]->commit(irRewriter);
1309 
1310  // Clean up all rewrites.
// Each op is erased at most once. When an unrealized_conversion_cast op is
// erased, its entry in `unresolvedMaterializations` is dropped as well.
1311  SingleEraseRewriter eraseRewriter(
1312  rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
1313  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1314  unresolvedMaterializations.erase(castOp);
1315  });
1316  for (auto &rewrite : rewrites)
1317  rewrite->cleanup(eraseRewriter);
1318 }
1319 
1320 //===----------------------------------------------------------------------===//
1321 // State Management
1322 //===----------------------------------------------------------------------===//
1323 
// NOTE(review): the first line of the function header is missing from this
// listing (numbering skips 1324); presumably
// `ValueVector ConversionPatternRewriterImpl::lookupOrDefault(`, matching the
// in-class declaration.
1325  Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1326  // Helper function that looks up a single value.
1327  auto lookup = [&](const ValueVector &values) -> ValueVector {
1328  assert(!values.empty() && "expected non-empty value vector");
1329 
1330  // If the pattern rollback is enabled, use the mapping to look up the
1331  // values.
// NOTE(review): numbering skips 1332. As listed, the `return` below is
// unconditional, which makes the IR-based lookup after it unreachable and
// contradicts the comment above. A guard (presumably on
// `config.allowPatternRollback`) appears to be missing from this listing.
1333  return mapping.lookup(values);
1334 
1335  // Otherwise, look up values by examining the IR. All replacements have
1336  // already been materialized in IR.
// Only an unresolved materialization (a cast op tracked in
// `unresolvedMaterializations`) whose outputs are exactly `values` counts as
// a mapping; its inputs are the mapped-to values.
1337  Operation *op = getCommonDefiningOp(values);
1338  if (!op)
1339  return {};
1340  auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1341  if (!castOp)
1342  return {};
1343  if (!this->unresolvedMaterializations.contains(castOp))
1344  return {};
1345  if (castOp.getOutputs() != values)
1346  return {};
1347  return castOp.getInputs();
1348  };
1349 
1350  // Helper function that looks up each value in `values` individually and then
1351  // composes the results. If that fails, it tries to look up the entire vector
1352  // at once.
1353  auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1354  // If possible, replace each value with (one or multiple) mapped values.
1355  ValueVector next;
1356  for (Value v : values) {
1357  ValueVector r = lookup({v});
1358  if (!r.empty()) {
1359  llvm::append_range(next, r);
1360  } else {
1361  next.push_back(v);
1362  }
1363  }
1364  if (next != values) {
1365  // At least one value was replaced.
1366  return next;
1367  }
1368 
1369  // Otherwise: Check if there is a mapping for the entire vector. Such
1370  // mappings are materializations. (N:M mappings are not supported for value
1371  // replacements.)
1372  //
1373  // Note: From a correctness point of view, materializations do not have to
1374  // be stored (and looked up) in the mapping. But for performance reasons,
1375  // we choose to reuse existing IR (when possible) instead of creating it
1376  // multiple times.
1377  ValueVector r = lookup(values);
1378  if (r.empty()) {
1379  // No mapping found: The lookup stops here.
1380  return {};
1381  }
1382  return r;
1383  };
1384 
1385  // Walk the mapping chain starting at `from`. Remember the deepest values that
1386  // have the desired types and, if requested, the deepest values that are not pure type conversions.
1387  ValueVector desiredValue;
1388  ValueVector current{from};
1389  ValueVector lastNonMaterialization{from};
1390  do {
1391  // Store the current value if the types match.
1392  bool match = TypeRange(ValueRange(current)) == desiredTypes;
1393  if (skipPureTypeConversions) {
1394  // Skip pure type conversions, if requested.
1395  bool pureConversion = isPureTypeConversion(current);
1396  match &= !pureConversion;
1397  // Keep track of the last mapped value that was not a pure type
1398  // conversion.
1399  if (!pureConversion)
1400  lastNonMaterialization = current;
1401  }
1402  if (match)
1403  desiredValue = current;
1404 
1405  // Lookup next value in the mapping.
1406  ValueVector next = composedLookup(current);
1407  if (next.empty())
1408  break;
1409  current = std::move(next);
1410  } while (true);
1411 
1412  // If specific types were requested, return the deepest match (an empty vector
1413  // if there is none). Otherwise return the leaf values, skipping pure type conversions if requested.
1414  if (!desiredTypes.empty())
1415  return desiredValue;
1416  if (skipPureTypeConversions)
1417  return lastNonMaterialization;
1418  return current;
1419 }
1420 
// NOTE(review): the first lines of the function header are missing from this
// listing (numbering skips 1421-1422); presumably
// `ValueVector ConversionPatternRewriterImpl::lookupOrNull(Value from,`.
1423  TypeRange desiredTypes) const {
1424  ValueVector result = lookupOrDefault(from, desiredTypes);
// Report "not found" (empty vector) in two cases: when `from` is unmapped (the
// lookup yields `from` itself), and when specific types were requested but
// the result does not have exactly those types.
1425  if (result == ValueVector{from} ||
1426  (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
1427  return {};
1428  return result;
1429 }
1430 
// NOTE(review): the function header is missing from this listing (numbering
// skips 1431); presumably
// `RewriterState ConversionPatternRewriterImpl::getCurrentState() {`.
// The snapshot records the current sizes of `rewrites`, `ignoredOps` and
// `replacedOps`; `resetState` truncates back to these sizes.
1432  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1433 }
1434 
// NOTE(review): the first line of the function header is missing from this
// listing (numbering skips 1435); presumably
// `void ConversionPatternRewriterImpl::resetState(RewriterState state,`.
1436  StringRef patternName) {
1437  // Undo any rewrites.
1438  undoRewrites(state.numRewrites, patternName);
1439 
1440  // Pop all of the recorded ignored operations that are no longer valid.
// These loops assume `state` was captured earlier by `getCurrentState`, so
// the saved sizes are not larger than the current ones. Otherwise they
// would call `pop_back` on an empty container.
1441  while (ignoredOps.size() != state.numIgnoredOperations)
1442  ignoredOps.pop_back();
1443 
1444  while (replacedOps.size() != state.numReplacedOps)
1445  replacedOps.pop_back();
1446 }
1447 
1448 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1449  StringRef patternName) {
1450  for (auto &rewrite :
1451  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1452  rewrite->rollback();
1453  rewrites.resize(numRewritesToKeep);
1454 }
1455 
// NOTE(review): the first line of the function header is missing from this
// listing (numbering skips 1456); presumably
// `LogicalResult ConversionPatternRewriterImpl::remapValues(`.
1457  StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1458  SmallVector<ValueVector> &remapped) {
1459  remapped.reserve(llvm::size(values));
1460 
1461  for (const auto &it : llvm::enumerate(values)) {
1462  Value operand = it.value();
1463  Type origType = operand.getType();
// An explicit `inputLoc` overrides each value's own location.
1464  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1465 
1466  if (!currentTypeConverter) {
1467  // The current pattern does not have a type converter. Pass the most
1468  // recently mapped values, excluding materializations. Materializations
1469  // are intentionally excluded because their presence may depend on other
1470  // patterns. Including materializations would make the lookup fragile
1471  // and unpredictable.
1472  remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1473  /*skipPureTypeConversions=*/true));
1474  continue;
1475  }
1476 
1477  // If there is no legal conversion, fail to match this pattern.
1478  SmallVector<Type, 1> legalTypes;
1479  if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
1480  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1481  diag << "unable to convert type for " << valueDiagTag << " #"
1482  << it.index() << ", type was " << origType;
1483  });
// Note: `remapped` may already contain entries for earlier values.
1484  return failure();
1485  }
1486  // If a type is converted to 0 types, there is nothing to do.
1487  if (legalTypes.empty()) {
1488  remapped.push_back({});
1489  continue;
1490  }
1491 
1492  ValueVector repl = lookupOrDefault(operand, legalTypes);
1493  if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1494  // Mapped values have the correct type or there is an existing
1495  // materialization. Or the operand is not mapped at all and has the
1496  // correct type.
1497  remapped.push_back(std::move(repl));
1498  continue;
1499  }
1500 
1501  // Create a materialization for the most recently mapped values.
1502  repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1503  /*skipPureTypeConversions=*/true);
// NOTE(review): numbering skips 1504. The line that starts this call is
// missing from this listing. Presumably it is
// `ValueRange castValues = buildUnresolvedMaterialization(`: the arguments
// below match that declaration, and `castValues` is used right after.
1505  MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1506  /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1507  /*originalType=*/origType, currentTypeConverter);
1508  remapped.push_back(castValues);
1509  }
1510  return success();
1511 }
1512 
// NOTE(review): the function header is missing from this listing (numbering
// skips 1513); presumably
// `bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {`.
1514  // Check to see if this operation is ignored or was replaced.
1515  return wasOpReplaced(op) || ignoredOps.count(op);
1516 }
1517 
// NOTE(review): the function header is missing from this listing (numbering
// skips 1518); presumably
// `bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {`.
1519  // Check to see if this operation was replaced.
// Both sets are consulted: `replacedOps` is maintained when pattern rollback
// is enabled, and `erasedOps` when it is disabled.
1520  return replacedOps.count(op) || erasedOps.count(op);
1521 }
1522 
1523 //===----------------------------------------------------------------------===//
1524 // Type Conversion
1525 //===----------------------------------------------------------------------===//
1526 
// NOTE(review): the first line of the function header is missing from this
// listing (numbering skips 1527); presumably
// `FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(`.
1528  Region *region, const TypeConverter &converter,
1529  TypeConverter::SignatureConversion *entryConversion) {
// Remember the converter for this region, even if the region turns out to be
// empty.
1530  regionToConverter[region] = &converter;
1531  if (region->empty())
1532  return nullptr;
1533 
1534  // Convert the arguments of each non-entry block within the region.
// `make_early_inc_range` is needed because `applySignatureConversion`
// inserts a new block right after `block` and erases `block` itself.
// Note: on failure, blocks visited earlier have already been converted.
1535  for (Block &block :
1536  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1537  // Compute the signature for the block with the provided converter.
1538  std::optional<TypeConverter::SignatureConversion> conversion =
1539  converter.convertBlockSignature(&block);
1540  if (!conversion)
1541  return failure();
1542  // Convert the block with the computed signature.
1543  applySignatureConversion(&block, &converter, *conversion);
1544  }
1545 
1546  // Convert the entry block. If an entry signature conversion was provided,
1547  // use that one. Otherwise, compute the signature with the type converter.
1548  if (entryConversion)
1549  return applySignatureConversion(&region->front(), &converter,
1550  *entryConversion);
1551  std::optional<TypeConverter::SignatureConversion> conversion =
1552  converter.convertBlockSignature(&region->front());
1553  if (!conversion)
1554  return failure();
1555  return applySignatureConversion(&region->front(), &converter, *conversion);
1556 }
1557 
// NOTE(review): the first line of the function header is missing from this
// listing (numbering skips 1558); presumably
// `Block *ConversionPatternRewriterImpl::applySignatureConversion(`.
1559  Block *block, const TypeConverter *converter,
1560  TypeConverter::SignatureConversion &signatureConversion) {
1561 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1562  // A block cannot be converted multiple times.
1563  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1564  llvm::report_fatal_error("block was already converted");
1565 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1566 
// NOTE(review): numbering skips 1567; a statement here is missing from this
// listing and cannot be recovered from what is visible.
1568 
1569  // If no arguments are being changed or added, there is nothing to do.
1570  unsigned origArgCount = block->getNumArguments();
1571  auto convertedTypes = signatureConversion.getConvertedTypes();
1572  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1573  return block;
1574 
1575  // Compute the locations of all block arguments in the new block.
// NOTE(review): numbering skips 1577. The second constructor argument of
// `newLocs` (the default location for arguments without an original
// counterpart) is missing, so the call is unterminated as listed.
1576  SmallVector<Location> newLocs(convertedTypes.size(),
1578  for (unsigned i = 0; i < origArgCount; ++i) {
1579  auto inputMap = signatureConversion.getInputMapping(i);
1580  if (!inputMap || inputMap->replacedWithValues())
1581  continue;
// Every new argument that an original argument maps to inherits its location.
1582  Location origLoc = block->getArgument(i).getLoc();
1583  for (unsigned j = 0; j < inputMap->size; ++j)
1584  newLocs[inputMap->inputNo + j] = origLoc;
1585  }
1586 
1587  // Insert a new block with the converted block argument types and move all ops
1588  // from the old block to the new block.
1589  Block *newBlock =
1590  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1591  convertedTypes, newLocs);
1592 
1593  // If a listener is attached to the dialect conversion, ops cannot be moved
1594  // to the destination block in bulk ("fast path"). This is because at the time
1595  // the notifications are sent, it is unknown which ops were moved. Instead,
1596  // ops should be moved one-by-one ("slow path"), so that a separate
1597  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1598  // a bit more efficient, so we try to do that when possible.
1599  bool fastPath = !config.listener;
1600  if (fastPath) {
// NOTE(review): numbering skips 1601. As listed, `appendRewrite` runs
// unconditionally on the fast path, yet `appendRewrite` asserts
// `config.allowPatternRollback`. A guard on that flag presumably precedes it.
1602  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1603  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1604  } else {
1605  while (!block->empty())
1606  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1607  }
1608 
1609  // Replace all uses of the old block with the new block.
1610  block->replaceAllUsesWith(newBlock);
1611 
1612  for (unsigned i = 0; i != origArgCount; ++i) {
1613  BlockArgument origArg = block->getArgument(i);
1614  Type origArgType = origArg.getType();
1615 
1616  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1617  signatureConversion.getInputMapping(i);
1618  if (!inputMap) {
1619  // This block argument was dropped and no replacement value was provided.
1620  // Materialize a replacement value "out of thin air".
1621  // Note: Materialization must be built here because we cannot find a
1622  // valid insertion point in the new block. (Will point to the old block.)
// NOTE(review): numbering skips 1624. The callee of this call is missing
// from this listing. The arguments match the `buildUnresolvedMaterialization`
// declaration and the call ends in `.front()`, so it is presumably that
// function.
1623  Value mat =
1625  MaterializationKind::Source,
1626  OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1627  origArg.getLoc(),
1628  /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1629  /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1630  /*isPureTypeConversion=*/false)
1631  .front();
1632  replaceAllUsesWith(origArg, mat, converter);
1633  continue;
1634  }
1635 
1636  if (inputMap->replacedWithValues()) {
1637  // This block argument was dropped and replacement values were provided.
1638  assert(inputMap->size == 0 &&
1639  "invalid to provide a replacement value when the argument isn't "
1640  "dropped");
1641  replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
1642  continue;
1643  }
1644 
1645  // This is a 1->1+ mapping.
1646  auto replArgs =
1647  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1648  replaceAllUsesWith(origArg, replArgs, converter);
1649  }
1650 
// NOTE(review): numbering skips 1651. A guard presumably precedes this
// `appendRewrite` call (likely on `config.allowPatternRollback`, which
// `appendRewrite` asserts).
1652  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1653 
1654  // Erase the old block. (It is just unlinked for now and will be erased during
1655  // cleanup.)
1656  rewriter.eraseBlock(block);
1657 
1658  return newBlock;
1659 }
1660 
1661 //===----------------------------------------------------------------------===//
1662 // Materializations
1663 //===----------------------------------------------------------------------===//
1664 
1665 /// Build an unresolved materialization operation given an output type and set
1666 /// of input operands.
1668  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1669  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1670  Type originalType, const TypeConverter *converter,
1671  bool isPureTypeConversion) {
1672  assert((!originalType || kind == MaterializationKind::Target) &&
1673  "original type is valid only for target materializations");
1674  assert(TypeRange(inputs) != outputTypes &&
1675  "materialization is not necessary");
1676 
1677  // Create an unresolved materialization. We use a new OpBuilder to avoid
1678  // tracking the materialization like we do for other operations.
1679  OpBuilder builder(outputTypes.front().getContext());
1680  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1681  UnrealizedConversionCastOp convertOp =
1682  UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1684  StringRef kindStr =
1685  kind == MaterializationKind::Source ? "source" : "target";
1686  convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1687  }
1689  convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
1690 
1691  // Register the materialization.
1692  unresolvedMaterializations[convertOp] =
1693  UnresolvedMaterializationInfo(converter, kind, originalType);
1695  if (!valuesToMap.empty())
1696  mapping.map(valuesToMap, convertOp.getResults());
1697  appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1698  std::move(valuesToMap));
1699  } else {
1700  patternMaterializations.insert(convertOp);
1701  }
1702  return convertOp.getResults();
1703 }
1704 
1706  Value value, const TypeConverter *converter) {
1707  assert(config.allowPatternRollback &&
1708  "this code path is valid only in rollback mode");
1709 
1710  // Try to find a replacement value with the same type in the conversion value
1711  // mapping. This includes cached materializations. We try to reuse those
1712  // instead of generating duplicate IR.
1713  ValueVector repl = lookupOrNull(value, value.getType());
1714  if (!repl.empty())
1715  return repl.front();
1716 
1717  // Check if the value is dead. No replacement value is needed in that case.
1718  // This is an approximate check that may have false negatives but does not
1719  // require computing and traversing an inverse mapping. (We may end up
1720  // building source materializations that are never used and that fold away.)
1721  if (llvm::all_of(value.getUsers(),
1722  [&](Operation *op) { return replacedOps.contains(op); }) &&
1723  !mapping.isMappedTo(value))
1724  return Value();
1725 
1726  // No replacement value was found. Get the latest replacement value
1727  // (regardless of the type) and build a source materialization to the
1728  // original type.
1729  repl = lookupOrNull(value);
1730 
1731  // Compute the insertion point of the materialization.
1733  if (repl.empty()) {
1734  // The source materialization has no inputs. Insert it right before the
1735  // value that it is replacing.
1736  ip = computeInsertPoint(value);
1737  } else {
1738  // Compute the "earliest" insertion point at which all values in `repl` are
1739  // defined. It is important to emit the materialization at that location
1740  // because the same materialization may be reused in a different context.
1741  // (That's because materializations are cached in the conversion value
1742  // mapping.) The insertion point of the materialization must be valid for
1743  // all future users that may be created later in the conversion process.
1744  ip = computeInsertPoint(repl);
1745  }
1747  MaterializationKind::Source, ip, value.getLoc(),
1748  /*valuesToMap=*/repl, /*inputs=*/repl,
1749  /*outputTypes=*/value.getType(),
1750  /*originalType=*/Type(), converter,
1751  /*isPureTypeConversion=*/!repl.empty())
1752  .front();
1753  return castValue;
1754 }
1755 
1756 //===----------------------------------------------------------------------===//
1757 // Rewriter Notification Hooks
1758 //===----------------------------------------------------------------------===//
1759 
1761  Operation *op, OpBuilder::InsertPoint previous) {
1762  // If no previous insertion point is provided, the op used to be detached.
1763  bool wasDetached = !previous.isSet();
1764  LLVM_DEBUG({
1765  logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1766  << ")";
1767  if (wasDetached)
1768  logger.getOStream() << " (was detached)";
1769  logger.getOStream() << "\n";
1770  });
1771 
1772  // In rollback mode, it is easier to misuse the API, so perform extra error
1773  // checking.
1774  assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1775  "attempting to insert into a block within a replaced/erased op");
1776 
1777  // In "no rollback" mode, the listener is always notified immediately.
1779  config.listener->notifyOperationInserted(op, previous);
1780 
1781  if (wasDetached) {
1782  // If the op was detached, it is most likely a newly created op. Add it to the
1783  // set of newly created ops, so that it will be legalized. If this op is
1784  // not a newly created op, it will be legalized a second time, which is
1785  // inefficient but harmless.
1786  patternNewOps.insert(op);
1787 
1789  // TODO: If the same op is inserted multiple times from a detached
1790  // state, the rollback mechanism may erase the same op multiple times.
1791  // This is a bug in the rollback-based dialect conversion driver.
1792  appendRewrite<CreateOperationRewrite>(op);
1793  } else {
1794  // In "no rollback" mode, there is an extra data structure for tracking
1795  // erased operations that must be kept up to date.
1796  erasedOps.erase(op);
1797  }
1798  return;
1799  }
1800 
1801  // The op was moved from one place to another.
1803  appendRewrite<MoveOperationRewrite>(op, previous);
1804 }
1805 
1806 /// Given that `fromRange` is about to be replaced with `toRange`, compute
1807 /// replacement values with the types of `fromRange`.
1808 static SmallVector<Value>
1810  const SmallVector<SmallVector<Value>> &toRange,
1811  const TypeConverter *converter) {
1812  assert(!impl.config.allowPatternRollback &&
1813  "this code path is valid only in 'no rollback' mode");
1814  SmallVector<Value> repls;
1815  for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1816  if (from.use_empty()) {
1817  // The replaced value is dead. No replacement value is needed.
1818  repls.push_back(Value());
1819  continue;
1820  }
1821 
1822  if (to.empty()) {
1823  // The replaced value is dropped. Materialize a replacement value "out of
1824  // thin air".
1825  Value srcMat = impl.buildUnresolvedMaterialization(
1826  MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1827  /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1828  /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1829  converter)[0];
1830  repls.push_back(srcMat);
1831  continue;
1832  }
1833 
1834  if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1835  // The replacement value already has the correct type. Use it directly.
1836  repls.push_back(to[0]);
1837  continue;
1838  }
1839 
1840  // The replacement value has the wrong type. Build a source materialization
1841  // to the original type.
1842  // TODO: This is a bit inefficient. We should try to reuse existing
1843  // materializations if possible. This would require an extension of the
1844  // `lookupOrDefault` API.
1845  Value srcMat = impl.buildUnresolvedMaterialization(
1846  MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1847  /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1848  /*originalType=*/Type(), converter)[0];
1849  repls.push_back(srcMat);
1850  }
1851 
1852  return repls;
1853 }
1854 
1856  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1857  assert(newValues.size() == op->getNumResults() &&
1858  "incorrect number of replacement values");
1859  LLVM_DEBUG({
1860  logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
1861  << ")\n";
1862  if (currentTypeConverter) {
1863  // If the user-provided replacement types are different from the
1864  // legalized types, as per the current type converter, print a note.
1865  // In most cases, the replacement types are expected to match the types
1866  // produced by the type converter, so this could indicate a bug in the
1867  // user code.
1868  for (auto [result, repls] :
1869  llvm::zip_equal(op->getResults(), newValues)) {
1870  Type resultType = result.getType();
1871  auto logProlog = [&, repls = repls]() {
1872  logger.startLine() << " Note: Replacing op result of type "
1873  << resultType << " with value(s) of type (";
1874  llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
1875  logger.getOStream() << v.getType();
1876  });
1877  logger.getOStream() << ")";
1878  };
1879  SmallVector<Type> convertedTypes;
1880  if (failed(currentTypeConverter->convertTypes(resultType,
1881  convertedTypes))) {
1882  logProlog();
1883  logger.getOStream() << ", but the type converter failed to legalize "
1884  "the original type.\n";
1885  continue;
1886  }
1887  if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
1888  logProlog();
1889  logger.getOStream() << ", but the legalized type(s) is/are (";
1890  llvm::interleaveComma(convertedTypes, logger.getOStream(),
1891  [&](Type t) { logger.getOStream() << t; });
1892  logger.getOStream() << ")\n";
1893  }
1894  }
1895  }
1896  });
1897 
1898  if (!config.allowPatternRollback) {
1899  // Pattern rollback is not allowed: materialize all IR changes immediately.
1901  *this, op->getResults(), newValues, currentTypeConverter);
1902  // Update internal data structures, so that there are no dangling pointers
1903  // to erased IR.
1904  op->walk([&](Operation *op) {
1905  erasedOps.insert(op);
1906  ignoredOps.remove(op);
1907  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1908  unresolvedMaterializations.erase(castOp);
1909  patternMaterializations.erase(castOp);
1910  }
1911  // The original op will be erased, so remove it from the set of
1912  // unlegalized ops.
1913  if (config.unlegalizedOps)
1914  config.unlegalizedOps->erase(op);
1915  });
1916  op->walk([&](Block *block) { erasedBlocks.insert(block); });
1917  // Replace the op with the replacement values and notify the listener.
1918  notifyingRewriter.replaceOp(op, repls);
1919  return;
1920  }
1921 
1922  assert(!ignoredOps.contains(op) && "operation was already replaced");
1923 #ifndef NDEBUG
1924  for (Value v : op->getResults())
1925  assert(!replacedValues.contains(v) &&
1926  "attempting to replace a value that was already replaced");
1927 #endif // NDEBUG
1928 
1929  // Check if replaced op is an unresolved materialization, i.e., an
1930  // unrealized_conversion_cast op that was created by the conversion driver.
1931  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1932  // Make sure that the user does not mess with unresolved materializations
1933  // that were inserted by the conversion driver. We keep track of these
1934  // ops in internal data structures.
1935  assert(!unresolvedMaterializations.contains(castOp) &&
1936  "attempting to replace/erase an unresolved materialization");
1937  }
1938 
1939  // Create mappings for each of the new result values.
1940  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
1941  mapping.map(static_cast<Value>(result), std::move(repl));
1942 
1943  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1944  // Mark this operation and all nested ops as replaced.
1945  op->walk([&](Operation *op) { replacedOps.insert(op); });
1946 }
1947 
1948 void ConversionPatternRewriterImpl::replaceAllUsesWith(
1949  Value from, ValueRange to, const TypeConverter *converter) {
1950  if (!config.allowPatternRollback) {
1951  SmallVector<Value> toConv = llvm::to_vector(to);
1952  SmallVector<Value> repls =
1953  getReplacementValues(*this, from, {toConv}, converter);
1954  IRRewriter r(from.getContext());
1955  Value repl = repls.front();
1956  if (!repl)
1957  return;
1958 
1959  performReplaceValue(r, from, repl);
1960  return;
1961  }
1962 
1963 #ifndef NDEBUG
1964  // Make sure that a value is not replaced multiple times. In rollback mode,
1965  // `replaceAllUsesWith` replaces not only all current uses of the given value,
1966  // but also all future uses that may be introduced by future pattern
1967  // applications. Therefore, it does not make sense to call
1968  // `replaceAllUsesWith` multiple times with the same value. Doing so would
1969  // overwrite the mapping and mess with the internal state of the dialect
1970  // conversion driver.
1971  assert(!replacedValues.contains(from) &&
1972  "attempting to replace a value that was already replaced");
1973  assert(!wasOpReplaced(from.getDefiningOp()) &&
1974  "attempting to replace a op result that was already replaced");
1975  replacedValues.insert(from);
1976 #endif // NDEBUG
1977 
1978  mapping.map(from, to);
1979  appendRewrite<ReplaceValueRewrite>(from, converter);
1980 }
1981 
1982 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
1983  if (!config.allowPatternRollback) {
1984  // Pattern rollback is not allowed: materialize all IR changes immediately.
1985  // Update internal data structures, so that there are no dangling pointers
1986  // to erased IR.
1987  block->walk([&](Operation *op) {
1988  erasedOps.insert(op);
1989  ignoredOps.remove(op);
1990  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1991  unresolvedMaterializations.erase(castOp);
1992  patternMaterializations.erase(castOp);
1993  }
1994  // The original op will be erased, so remove it from the set of
1995  // unlegalized ops.
1996  if (config.unlegalizedOps)
1997  config.unlegalizedOps->erase(op);
1998  });
1999  block->walk([&](Block *block) { erasedBlocks.insert(block); });
2000  // Erase the block and notify the listener.
2001  notifyingRewriter.eraseBlock(block);
2002  return;
2003  }
2004 
2005  assert(!wasOpReplaced(block->getParentOp()) &&
2006  "attempting to erase a block within a replaced/erased op");
2007  appendRewrite<EraseBlockRewrite>(block);
2008 
2009  // Unlink the block from its parent region. The block is kept in the rewrite
2010  // object and will be actually destroyed when rewrites are applied. This
2011  // allows us to keep the operations in the block live and undo the removal by
2012  // re-inserting the block.
2013  block->getParent()->getBlocks().remove(block);
2014 
2015  // Mark all nested ops as erased.
2016  block->walk([&](Operation *op) { replacedOps.insert(op); });
2017 }
2018 
2019 void ConversionPatternRewriterImpl::notifyBlockInserted(
2020  Block *block, Region *previous, Region::iterator previousIt) {
2021  // If no previous insertion point is provided, the block used to be detached.
2022  bool wasDetached = !previous;
2023  Operation *newParentOp = block->getParentOp();
2024  LLVM_DEBUG(
2025  {
2026  Operation *parent = newParentOp;
2027  if (parent) {
2028  logger.startLine() << "** Insert Block into : '" << parent->getName()
2029  << "' (" << parent << ")";
2030  } else {
2031  logger.startLine()
2032  << "** Insert Block into detached Region (nullptr parent op)";
2033  }
2034  if (wasDetached)
2035  logger.getOStream() << " (was detached)";
2036  logger.getOStream() << "\n";
2037  });
2038 
2039  // In rollback mode, it is easier to misuse the API, so perform extra error
2040  // checking.
2041  assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2042  "attempting to insert into a region within a replaced/erased op");
2043  (void)newParentOp;
2044 
2045  // In "no rollback" mode, the listener is always notified immediately.
2046  if (!config.allowPatternRollback && config.listener)
2047  config.listener->notifyBlockInserted(block, previous, previousIt);
2048 
2049  patternInsertedBlocks.insert(block);
2050 
2051  if (wasDetached) {
2052  // If the block was detached, it is most likely a newly created block.
2053  if (config.allowPatternRollback) {
2054  // TODO: If the same block is inserted multiple times from a detached
2055  // state, the rollback mechanism may erase the same block multiple times.
2056  // This is a bug in the rollback-based dialect conversion driver.
2057  appendRewrite<CreateBlockRewrite>(block);
2058  } else {
2059  // In "no rollback" mode, there is an extra data structure for tracking
2060  // erased blocks that must be kept up to date.
2061  erasedBlocks.erase(block);
2062  }
2063  return;
2064  }
2065 
2066  // The block was moved from one place to another.
2067  if (config.allowPatternRollback)
2068  appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2069 }
2070 
2071 void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
2072  Block *dest,
2073  Block::iterator before) {
2074  appendRewrite<InlineBlockRewrite>(dest, source, before);
2075 }
2076 
2077 void ConversionPatternRewriterImpl::notifyMatchFailure(
2078  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
2079  LLVM_DEBUG({
2081  reasonCallback(diag);
2082  logger.startLine() << "** Failure : " << diag.str() << "\n";
2083  if (config.notifyCallback)
2084  config.notifyCallback(diag);
2085  });
2086 }
2087 
2088 //===----------------------------------------------------------------------===//
2089 // ConversionPatternRewriter
2090 //===----------------------------------------------------------------------===//
2091 
2092 ConversionPatternRewriter::ConversionPatternRewriter(
2093  MLIRContext *ctx, const ConversionConfig &config)
2094  : PatternRewriter(ctx),
2095  impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
2096  setListener(impl.get());
2097 }
2098 
2100 
2102  return impl->config;
2103 }
2104 
2106  assert(op && newOp && "expected non-null op");
2107  replaceOp(op, newOp->getResults());
2108 }
2109 
2111  assert(op->getNumResults() == newValues.size() &&
2112  "incorrect # of replacement values");
2113 
2114  // If the current insertion point is before the erased operation, we adjust
2115  // the insertion point to be after the operation.
2116  if (getInsertionPoint() == op->getIterator())
2118 
2120  llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2121  return v ? SmallVector<Value>{v} : SmallVector<Value>();
2122  });
2123  impl->replaceOp(op, std::move(newVals));
2124 }
2125 
2127  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2128  assert(op->getNumResults() == newValues.size() &&
2129  "incorrect # of replacement values");
2130 
2131  // If the current insertion point is before the erased operation, we adjust
2132  // the insertion point to be after the operation.
2133  if (getInsertionPoint() == op->getIterator())
2135 
2136  impl->replaceOp(op, std::move(newValues));
2137 }
2138 
2140  LLVM_DEBUG({
2141  impl->logger.startLine()
2142  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2143  });
2144 
2145  // If the current insertion point is before the erased operation, we adjust
2146  // the insertion point to be after the operation.
2147  if (getInsertionPoint() == op->getIterator())
2149 
2150  SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2151  impl->replaceOp(op, std::move(nullRepls));
2152 }
2153 
2155  impl->eraseBlock(block);
2156 }
2157 
2159  Block *block, TypeConverter::SignatureConversion &conversion,
2160  const TypeConverter *converter) {
2161  assert(!impl->wasOpReplaced(block->getParentOp()) &&
2162  "attempting to apply a signature conversion to a block within a "
2163  "replaced/erased op");
2164  return impl->applySignatureConversion(block, converter, conversion);
2165 }
2166 
2168  Region *region, const TypeConverter &converter,
2169  TypeConverter::SignatureConversion *entryConversion) {
2170  assert(!impl->wasOpReplaced(region->getParentOp()) &&
2171  "attempting to apply a signature conversion to a block within a "
2172  "replaced/erased op");
2173  return impl->convertRegionTypes(region, converter, entryConversion);
2174 }
2175 
2177  LLVM_DEBUG({
2178  impl->logger.startLine() << "** Replace Value : '" << from << "'";
2179  if (auto blockArg = dyn_cast<BlockArgument>(from)) {
2180  if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
2181  impl->logger.getOStream() << " (in region of '" << parentOp->getName()
2182  << "' (" << parentOp << ")\n";
2183  } else {
2184  impl->logger.getOStream() << " (unlinked block)\n";
2185  }
2186  }
2187  });
2188  impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
2189 }
2190 
2192  SmallVector<ValueVector> remappedValues;
2193  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2194  remappedValues)))
2195  return nullptr;
2196  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2197  return remappedValues.front().front();
2198 }
2199 
2200 LogicalResult
2202  SmallVectorImpl<Value> &results) {
2203  if (keys.empty())
2204  return success();
2205  SmallVector<ValueVector> remapped;
2206  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2207  remapped)))
2208  return failure();
2209  for (const auto &values : remapped) {
2210  assert(values.size() == 1 && "1:N conversion not supported");
2211  results.push_back(values.front());
2212  }
2213  return success();
2214 }
2215 
2217  Block::iterator before,
2218  ValueRange argValues) {
2219 #ifndef NDEBUG
2220  assert(argValues.size() == source->getNumArguments() &&
2221  "incorrect # of argument replacement values");
2222  assert(!impl->wasOpReplaced(source->getParentOp()) &&
2223  "attempting to inline a block from a replaced/erased op");
2224  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2225  "attempting to inline a block into a replaced/erased op");
2226  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2227  // The source block will be deleted, so it should not have any users (i.e.,
2228  // there should be no predecessors).
2229  assert(llvm::all_of(source->getUsers(), opIgnored) &&
2230  "expected 'source' to have no predecessors");
2231 #endif // NDEBUG
2232 
2233  // If a listener is attached to the dialect conversion, ops cannot be moved
2234  // to the destination block in bulk ("fast path"). This is because at the time
2235  // the notifications are sent, it is unknown which ops were moved. Instead,
2236  // ops should be moved one-by-one ("slow path"), so that a separate
2237  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2238  // a bit more efficient, so we try to do that when possible.
2239  bool fastPath = !getConfig().listener;
2240 
2241  if (fastPath && impl->config.allowPatternRollback)
2242  impl->inlineBlockBefore(source, dest, before);
2243 
2244  // Replace all uses of block arguments.
2245  for (auto it : llvm::zip(source->getArguments(), argValues))
2246  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2247 
2248  if (fastPath) {
2249  // Move all ops at once.
2250  dest->getOperations().splice(before, source->getOperations());
2251  } else {
2252  // Move op by op.
2253  while (!source->empty())
2254  moveOpBefore(&source->front(), dest, before);
2255  }
2256 
2257  // If the current insertion point is within the source block, adjust the
2258  // insertion point to the destination block.
2259  if (getInsertionBlock() == source)
2260  setInsertionPoint(dest, getInsertionPoint());
2261 
2262  // Erase the source block.
2263  eraseBlock(source);
2264 }
2265 
2267  if (!impl->config.allowPatternRollback) {
2268  // Pattern rollback is not allowed: no extra bookkeeping is needed.
2270  return;
2271  }
2272  assert(!impl->wasOpReplaced(op) &&
2273  "attempting to modify a replaced/erased op");
2274 #ifndef NDEBUG
2275  impl->pendingRootUpdates.insert(op);
2276 #endif
2277  impl->appendRewrite<ModifyOperationRewrite>(op);
2278 }
2279 
2281  impl->patternModifiedOps.insert(op);
2282  if (!impl->config.allowPatternRollback) {
2284  if (getConfig().listener)
2285  getConfig().listener->notifyOperationModified(op);
2286  return;
2287  }
2288 
2289  // There is nothing to do here, we only need to track the operation at the
2290  // start of the update.
2291 #ifndef NDEBUG
2292  assert(!impl->wasOpReplaced(op) &&
2293  "attempting to modify a replaced/erased op");
2294  assert(impl->pendingRootUpdates.erase(op) &&
2295  "operation did not have a pending in-place update");
2296 #endif
2297 }
2298 
2300  if (!impl->config.allowPatternRollback) {
2302  return;
2303  }
2304 #ifndef NDEBUG
2305  assert(impl->pendingRootUpdates.erase(op) &&
2306  "operation did not have a pending in-place update");
2307 #endif
2308  // Erase the last update for this operation.
2309  auto it = llvm::find_if(
2310  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2311  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2312  return modifyRewrite && modifyRewrite->getOperation() == op;
2313  });
2314  assert(it != impl->rewrites.rend() && "no root update started on op");
2315  (*it)->rollback();
2316  int updateIdx = std::prev(impl->rewrites.rend()) - it;
2317  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2318 }
2319 
2321  return *impl;
2322 }
2323 
2324 //===----------------------------------------------------------------------===//
2325 // ConversionPattern
2326 //===----------------------------------------------------------------------===//
2327 
2329  ArrayRef<ValueRange> operands) const {
2330  SmallVector<Value> oneToOneOperands;
2331  oneToOneOperands.reserve(operands.size());
2332  for (ValueRange operand : operands) {
2333  if (operand.size() != 1)
2334  return failure();
2335 
2336  oneToOneOperands.push_back(operand.front());
2337  }
2338  return std::move(oneToOneOperands);
2339 }
2340 
2341 LogicalResult
2343  PatternRewriter &rewriter) const {
2344  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2345  auto &rewriterImpl = dialectRewriter.getImpl();
2346 
2347  // Track the current conversion pattern type converter in the rewriter.
2348  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2349  getTypeConverter());
2350 
2351  // Remap the operands of the operation.
2352  SmallVector<ValueVector> remapped;
2353  if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2354  op->getOperands(), remapped))) {
2355  return failure();
2356  }
2357  SmallVector<ValueRange> remappedAsRange =
2358  llvm::to_vector_of<ValueRange>(remapped);
2359  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2360 }
2361 
2362 //===----------------------------------------------------------------------===//
2363 // OperationLegalizer
2364 //===----------------------------------------------------------------------===//
2365 
2366 namespace {
2367 /// A set of rewrite patterns that can be used to legalize a given operation.
2368 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2369 
2370 /// This class defines a recursive operation legalizer.
2371 class OperationLegalizer {
2372 public:
2373  using LegalizationAction = ConversionTarget::LegalizationAction;
2374 
2375  OperationLegalizer(ConversionPatternRewriter &rewriter,
2376  const ConversionTarget &targetInfo,
2378 
2379  /// Returns true if the given operation is known to be illegal on the target.
2380  bool isIllegal(Operation *op) const;
2381 
2382  /// Attempt to legalize the given operation. Returns success if the operation
2383  /// was legalized, failure otherwise.
2384  LogicalResult legalize(Operation *op);
2385 
2386  /// Returns the conversion target in use by the legalizer.
2387  const ConversionTarget &getTarget() { return target; }
2388 
2389 private:
2390  /// Attempt to legalize the given operation by folding it.
2391  LogicalResult legalizeWithFold(Operation *op);
2392 
2393  /// Attempt to legalize the given operation by applying a pattern. Returns
2394  /// success if the operation was legalized, failure otherwise.
2395  LogicalResult legalizeWithPattern(Operation *op);
2396 
2397  /// Return true if the given pattern may be applied to the given operation,
2398  /// false otherwise.
2399  bool canApplyPattern(Operation *op, const Pattern &pattern);
2400 
2401  /// Legalize the resultant IR after successfully applying the given pattern.
2402  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
2403  const RewriterState &curState,
2404  const SetVector<Operation *> &newOps,
2405  const SetVector<Operation *> &modifiedOps,
2406  const SetVector<Block *> &insertedBlocks);
2407 
2408  /// Legalizes the actions registered during the execution of a pattern.
2409  LogicalResult
2410  legalizePatternBlockRewrites(Operation *op,
2411  const SetVector<Block *> &insertedBlocks,
2412  const SetVector<Operation *> &newOps);
2413  LogicalResult
2414  legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
2415  LogicalResult
2416  legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2417 
2418  //===--------------------------------------------------------------------===//
2419  // Cost Model
2420  //===--------------------------------------------------------------------===//
2421 
2422  /// Build an optimistic legalization graph given the provided patterns. This
2423  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2424  /// patterns for operations that are not directly legal, but may be
2425  /// transitively legal for the current target given the provided patterns.
2426  void buildLegalizationGraph(
2427  LegalizationPatterns &anyOpLegalizerPatterns,
2429 
2430  /// Compute the benefit of each node within the computed legalization graph.
2431  /// This orders the patterns within 'legalizerPatterns' based upon two
2432  /// criteria:
2433  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2434  /// represent the more direct mapping to the target.
2435  /// 2) When comparing patterns with the same legalization depth, prefer the
2436  /// pattern with the highest PatternBenefit. This allows for users to
2437  /// prefer specific legalizations over others.
2438  void computeLegalizationGraphBenefit(
2439  LegalizationPatterns &anyOpLegalizerPatterns,
2441 
2442  /// Compute the legalization depth when legalizing an operation of the given
2443  /// type.
2444  unsigned computeOpLegalizationDepth(
2445  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2447 
2448  /// Apply the conversion cost model to the given set of patterns, and return
2449  /// the smallest legalization depth of any of the patterns. See
2450  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2451  unsigned applyCostModelToPatterns(
2452  LegalizationPatterns &patterns,
2453  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2455 
2456  /// The current set of patterns that have been applied.
2457  SmallPtrSet<const Pattern *, 8> appliedPatterns;
2458 
2459  /// The rewriter to use when converting operations.
2460  ConversionPatternRewriter &rewriter;
2461 
2462  /// The legalization information provided by the target.
2463  const ConversionTarget &target;
2464 
2465  /// The pattern applicator to use for conversions.
2466  PatternApplicator applicator;
2467 };
2468 } // namespace
2469 
2470 OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2471  const ConversionTarget &targetInfo,
2473  : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2474  // The set of patterns that can be applied to illegal operations to transform
2475  // them into legal ones.
2477  LegalizationPatterns anyOpLegalizerPatterns;
2478 
2479  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2480  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2481 }
2482 
2483 bool OperationLegalizer::isIllegal(Operation *op) const {
2484  return target.isIllegal(op);
2485 }
2486 
2487 LogicalResult OperationLegalizer::legalize(Operation *op) {
2488 #ifndef NDEBUG
2489  const char *logLineComment =
2490  "//===-------------------------------------------===//\n";
2491 
2492  auto &logger = rewriter.getImpl().logger;
2493 #endif
2494 
2495  // Check to see if the operation is ignored and doesn't need to be converted.
2496  bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2497 
2498  LLVM_DEBUG({
2499  logger.getOStream() << "\n";
2500  logger.startLine() << logLineComment;
2501  logger.startLine() << "Legalizing operation : ";
2502  // Do not print the operation name if the operation is ignored. Ignored ops
2503  // may have been erased and should not be accessed. The pointer can be
2504  // printed safely.
2505  if (!isIgnored)
2506  logger.getOStream() << "'" << op->getName() << "' ";
2507  logger.getOStream() << "(" << op << ") {\n";
2508  logger.indent();
2509 
2510  // If the operation has no regions, just print it here.
2511  if (!isIgnored && op->getNumRegions() == 0) {
2512  logger.startLine() << OpWithFlags(op,
2513  OpPrintingFlags().printGenericOpForm())
2514  << "\n";
2515  }
2516  });
2517 
2518  if (isIgnored) {
2519  LLVM_DEBUG({
2520  logSuccess(logger, "operation marked 'ignored' during conversion");
2521  logger.startLine() << logLineComment;
2522  });
2523  return success();
2524  }
2525 
2526  // Check if this operation is legal on the target.
2527  if (auto legalityInfo = target.isLegal(op)) {
2528  LLVM_DEBUG({
2529  logSuccess(
2530  logger, "operation marked legal by the target{0}",
2531  legalityInfo->isRecursivelyLegal
2532  ? "; NOTE: operation is recursively legal; skipping internals"
2533  : "");
2534  logger.startLine() << logLineComment;
2535  });
2536 
2537  // If this operation is recursively legal, mark its children as ignored so
2538  // that we don't consider them for legalization.
2539  if (legalityInfo->isRecursivelyLegal) {
2540  op->walk([&](Operation *nested) {
2541  if (op != nested)
2542  rewriter.getImpl().ignoredOps.insert(nested);
2543  });
2544  }
2545 
2546  return success();
2547  }
2548 
2549  // If the operation is not legal, try to fold it in-place if the folding mode
2550  // is 'BeforePatterns'. 'Never' will skip this.
2551  const ConversionConfig &config = rewriter.getConfig();
2553  if (succeeded(legalizeWithFold(op))) {
2554  LLVM_DEBUG({
2555  logSuccess(logger, "operation was folded");
2556  logger.startLine() << logLineComment;
2557  });
2558  return success();
2559  }
2560  }
2561 
2562  // Otherwise, we need to apply a legalization pattern to this operation.
2563  if (succeeded(legalizeWithPattern(op))) {
2564  LLVM_DEBUG({
2565  logSuccess(logger, "");
2566  logger.startLine() << logLineComment;
2567  });
2568  return success();
2569  }
2570 
2571  // If the operation can't be legalized via patterns, try to fold it in-place
2572  // if the folding mode is 'AfterPatterns'.
2574  if (succeeded(legalizeWithFold(op))) {
2575  LLVM_DEBUG({
2576  logSuccess(logger, "operation was folded");
2577  logger.startLine() << logLineComment;
2578  });
2579  return success();
2580  }
2581  }
2582 
2583  LLVM_DEBUG({
2584  logFailure(logger, "no matched legalization pattern");
2585  logger.startLine() << logLineComment;
2586  });
2587  return failure();
2588 }
2589 
/// Move the contents out of `obj` and return them, then reassign `obj` with a
/// value-initialized `T` so the caller is left with a valid, empty object
/// rather than an unspecified moved-from one.
template <typename T>
static T moveAndReset(T &obj) {
  T taken(std::move(obj));
  obj = T();
  return taken;
}
2598 
2599 LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2600  auto &rewriterImpl = rewriter.getImpl();
2601  LLVM_DEBUG({
2602  rewriterImpl.logger.startLine() << "* Fold {\n";
2603  rewriterImpl.logger.indent();
2604  });
2605 
2606  // Clear pattern state, so that the next pattern application starts with a
2607  // clean slate. (The op/block sets are populated by listener notifications.)
2608  auto cleanup = llvm::make_scope_exit([&]() {
2609  rewriterImpl.patternNewOps.clear();
2610  rewriterImpl.patternModifiedOps.clear();
2611  rewriterImpl.patternInsertedBlocks.clear();
2612  });
2613 
2614  // Upon failure, undo all changes made by the folder.
2615  RewriterState curState = rewriterImpl.getCurrentState();
2616 
2617  // Try to fold the operation.
2618  StringRef opName = op->getName().getStringRef();
2619  SmallVector<Value, 2> replacementValues;
2620  SmallVector<Operation *, 2> newOps;
2621  rewriter.setInsertionPoint(op);
2622  rewriter.startOpModification(op);
2623  if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2624  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2625  rewriter.cancelOpModification(op);
2626  return failure();
2627  }
2628  rewriter.finalizeOpModification(op);
2629 
2630  // An empty list of replacement values indicates that the fold was in-place.
2631  // As the operation changed, a new legalization needs to be attempted.
2632  if (replacementValues.empty())
2633  return legalize(op);
2634 
2635  // Insert a replacement for 'op' with the folded replacement values.
2636  rewriter.replaceOp(op, replacementValues);
2637 
2638  // Recursively legalize any new constant operations.
2639  for (Operation *newOp : newOps) {
2640  if (failed(legalize(newOp))) {
2641  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2642  "failed to legalize generated constant '{0}'",
2643  newOp->getName()));
2644  if (!rewriter.getConfig().allowPatternRollback) {
2645  // Rolling back a folder is like rolling back a pattern.
2646  llvm::report_fatal_error(
2647  "op '" + opName +
2648  "' folder rollback of IR modifications requested");
2649  }
2650  rewriterImpl.resetState(
2651  curState, std::string(op->getName().getStringRef()) + " folder");
2652  return failure();
2653  }
2654  }
2655 
2656  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2657  return success();
2658 }
2659 
2660 /// Report a fatal error indicating that newly produced or modified IR could
2661 /// not be legalized.
2662 static void
2664  const SetVector<Operation *> &newOps,
2665  const SetVector<Operation *> &modifiedOps,
2666  const SetVector<Block *> &insertedBlocks) {
2667  auto newOpNames = llvm::map_range(
2668  newOps, [](Operation *op) { return op->getName().getStringRef(); });
2669  auto modifiedOpNames = llvm::map_range(
2670  modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
2671  StringRef detachedBlockStr = "(detached block)";
2672  auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
2673  if (block->getParentOp())
2674  return block->getParentOp()->getName().getStringRef();
2675  return detachedBlockStr;
2676  });
2677  llvm::report_fatal_error(
2678  "pattern '" + pattern.getDebugName() +
2679  "' produced IR that could not be legalized. " + "new ops: {" +
2680  llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
2681  llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
2682  llvm::join(insertedBlockNames, ", ") + "}");
2683 }
2684 
2685 LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2686  auto &rewriterImpl = rewriter.getImpl();
2687  const ConversionConfig &config = rewriter.getConfig();
2688 
2689 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2690  Operation *checkOp;
2691  std::optional<OperationFingerPrint> topLevelFingerPrint;
2692  if (!rewriterImpl.config.allowPatternRollback) {
2693  // The op may be getting erased, so we have to check the parent op.
2694  // (In rare cases, a pattern may even erase the parent op, which will cause
2695  // a crash here. Expensive checks are "best effort".) Skip the check if the
2696  // op does not have a parent op.
2697  if ((checkOp = op->getParentOp())) {
2698  if (!op->getContext()->isMultithreadingEnabled()) {
2699  topLevelFingerPrint = OperationFingerPrint(checkOp);
2700  } else {
2701  // Another thread may be modifying a sibling operation. Therefore, the
2702  // fingerprinting mechanism of the parent op works only in
2703  // single-threaded mode.
2704  LLVM_DEBUG({
2705  rewriterImpl.logger.startLine()
2706  << "WARNING: Multi-threadeding is enabled. Some dialect "
2707  "conversion expensive checks are skipped in multithreading "
2708  "mode!\n";
2709  });
2710  }
2711  }
2712  }
2713 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2714 
2715  // Functor that returns if the given pattern may be applied.
2716  auto canApply = [&](const Pattern &pattern) {
2717  bool canApply = canApplyPattern(op, pattern);
2718  if (canApply && config.listener)
2719  config.listener->notifyPatternBegin(pattern, op);
2720  return canApply;
2721  };
2722 
2723  // Functor that cleans up the rewriter state after a pattern failed to match.
2724  RewriterState curState = rewriterImpl.getCurrentState();
2725  auto onFailure = [&](const Pattern &pattern) {
2726  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2727  if (!rewriterImpl.config.allowPatternRollback) {
2728  // Erase all unresolved materializations.
2729  for (auto op : rewriterImpl.patternMaterializations) {
2730  rewriterImpl.unresolvedMaterializations.erase(op);
2731  op.erase();
2732  }
2733  rewriterImpl.patternMaterializations.clear();
2734 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2735  // Expensive pattern check that can detect API violations.
2736  if (checkOp) {
2737  OperationFingerPrint fingerPrintAfterPattern(checkOp);
2738  if (fingerPrintAfterPattern != *topLevelFingerPrint)
2739  llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2740  "' returned failure but IR did change");
2741  }
2742 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2743  }
2744  rewriterImpl.patternNewOps.clear();
2745  rewriterImpl.patternModifiedOps.clear();
2746  rewriterImpl.patternInsertedBlocks.clear();
2747  LLVM_DEBUG({
2748  logFailure(rewriterImpl.logger, "pattern failed to match");
2749  if (rewriterImpl.config.notifyCallback) {
2751  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2752  << "\" on op:\n"
2753  << *op;
2754  rewriterImpl.config.notifyCallback(diag);
2755  }
2756  });
2757  if (config.listener)
2758  config.listener->notifyPatternEnd(pattern, failure());
2759  rewriterImpl.resetState(curState, pattern.getDebugName());
2760  appliedPatterns.erase(&pattern);
2761  };
2762 
2763  // Functor that performs additional legalization when a pattern is
2764  // successfully applied.
2765  auto onSuccess = [&](const Pattern &pattern) {
2766  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2767  if (!rewriterImpl.config.allowPatternRollback) {
2768  // Eagerly erase unused materializations.
2769  for (auto op : rewriterImpl.patternMaterializations) {
2770  if (op->use_empty()) {
2771  rewriterImpl.unresolvedMaterializations.erase(op);
2772  op.erase();
2773  }
2774  }
2775  rewriterImpl.patternMaterializations.clear();
2776  }
2777  SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2778  SetVector<Operation *> modifiedOps =
2779  moveAndReset(rewriterImpl.patternModifiedOps);
2780  SetVector<Block *> insertedBlocks =
2781  moveAndReset(rewriterImpl.patternInsertedBlocks);
2782  auto result = legalizePatternResult(op, pattern, curState, newOps,
2783  modifiedOps, insertedBlocks);
2784  appliedPatterns.erase(&pattern);
2785  if (failed(result)) {
2786  if (!rewriterImpl.config.allowPatternRollback)
2787  reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
2788  insertedBlocks);
2789  rewriterImpl.resetState(curState, pattern.getDebugName());
2790  }
2791  if (config.listener)
2792  config.listener->notifyPatternEnd(pattern, result);
2793  return result;
2794  };
2795 
2796  // Try to match and rewrite a pattern on this operation.
2797  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2798  onSuccess);
2799 }
2800 
2801 bool OperationLegalizer::canApplyPattern(Operation *op,
2802  const Pattern &pattern) {
2803  LLVM_DEBUG({
2804  auto &os = rewriter.getImpl().logger;
2805  os.getOStream() << "\n";
2806  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2807  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2808  os.getOStream() << ")' {\n";
2809  os.indent();
2810  });
2811 
2812  // Ensure that we don't cycle by not allowing the same pattern to be
2813  // applied twice in the same recursion stack if it is not known to be safe.
2814  if (!pattern.hasBoundedRewriteRecursion() &&
2815  !appliedPatterns.insert(&pattern).second) {
2816  LLVM_DEBUG(
2817  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2818  return false;
2819  }
2820  return true;
2821 }
2822 
2823 LogicalResult OperationLegalizer::legalizePatternResult(
2824  Operation *op, const Pattern &pattern, const RewriterState &curState,
2825  const SetVector<Operation *> &newOps,
2826  const SetVector<Operation *> &modifiedOps,
2827  const SetVector<Block *> &insertedBlocks) {
2828  [[maybe_unused]] auto &impl = rewriter.getImpl();
2829  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2830 
2831 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2832  // Check that the root was either replaced or updated in place.
2833  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2834  auto replacedRoot = [&] {
2835  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2836  };
2837  auto updatedRootInPlace = [&] {
2838  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2839  };
2840  if (!replacedRoot() && !updatedRootInPlace())
2841  llvm::report_fatal_error(
2842  "expected pattern to replace the root operation or modify it in place");
2843 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2844 
2845  // Legalize each of the actions registered during application.
2846  if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
2847  failed(legalizePatternRootUpdates(modifiedOps)) ||
2848  failed(legalizePatternCreatedOperations(newOps))) {
2849  return failure();
2850  }
2851 
2852  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2853  return success();
2854 }
2855 
2856 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2857  Operation *op, const SetVector<Block *> &insertedBlocks,
2858  const SetVector<Operation *> &newOps) {
2859  ConversionPatternRewriterImpl &impl = rewriter.getImpl();
2860  SmallPtrSet<Operation *, 16> alreadyLegalized;
2861 
2862  // If the pattern moved or created any blocks, make sure the types of block
2863  // arguments get legalized.
2864  for (Block *block : insertedBlocks) {
2865  if (impl.erasedBlocks.contains(block))
2866  continue;
2867 
2868  // Only check blocks outside of the current operation.
2869  Operation *parentOp = block->getParentOp();
2870  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2871  continue;
2872 
2873  // If the region of the block has a type converter, try to convert the block
2874  // directly.
2875  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2876  std::optional<TypeConverter::SignatureConversion> conversion =
2877  converter->convertBlockSignature(block);
2878  if (!conversion) {
2879  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2880  "block"));
2881  return failure();
2882  }
2883  impl.applySignatureConversion(block, converter, *conversion);
2884  continue;
2885  }
2886 
2887  // Otherwise, try to legalize the parent operation if it was not generated
2888  // by this pattern. This is because we will attempt to legalize the parent
2889  // operation, and blocks in regions created by this pattern will already be
2890  // legalized later on.
2891  if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
2892  if (failed(legalize(parentOp))) {
2893  LLVM_DEBUG(logFailure(
2894  impl.logger, "operation '{0}'({1}) became illegal after rewrite",
2895  parentOp->getName(), parentOp));
2896  return failure();
2897  }
2898  }
2899  }
2900  return success();
2901 }
2902 
2903 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2904  const SetVector<Operation *> &newOps) {
2905  for (Operation *op : newOps) {
2906  if (failed(legalize(op))) {
2907  LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2908  "failed to legalize generated operation '{0}'({1})",
2909  op->getName(), op));
2910  return failure();
2911  }
2912  }
2913  return success();
2914 }
2915 
2916 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2917  const SetVector<Operation *> &modifiedOps) {
2918  for (Operation *op : modifiedOps) {
2919  if (failed(legalize(op))) {
2920  LLVM_DEBUG(
2921  logFailure(rewriter.getImpl().logger,
2922  "failed to legalize operation updated in-place '{0}'",
2923  op->getName()));
2924  return failure();
2925  }
2926  }
2927  return success();
2928 }
2929 
2930 //===----------------------------------------------------------------------===//
2931 // Cost Model
2932 //===----------------------------------------------------------------------===//
2933 
2934 void OperationLegalizer::buildLegalizationGraph(
2935  LegalizationPatterns &anyOpLegalizerPatterns,
2936  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2937  // A mapping between an operation and a set of operations that can be used to
2938  // generate it.
2940  // A mapping between an operation and any currently invalid patterns it has.
2942  // A worklist of patterns to consider for legality.
2943  SetVector<const Pattern *> patternWorklist;
2944 
2945  // Build the mapping from operations to the parent ops that may generate them.
2946  applicator.walkAllPatterns([&](const Pattern &pattern) {
2947  std::optional<OperationName> root = pattern.getRootKind();
2948 
2949  // If the pattern has no specific root, we can't analyze the relationship
2950  // between the root op and generated operations. Given that, add all such
2951  // patterns to the legalization set.
2952  if (!root) {
2953  anyOpLegalizerPatterns.push_back(&pattern);
2954  return;
2955  }
2956 
2957  // Skip operations that are always known to be legal.
2958  if (target.getOpAction(*root) == LegalizationAction::Legal)
2959  return;
2960 
2961  // Add this pattern to the invalid set for the root op and record this root
2962  // as a parent for any generated operations.
2963  invalidPatterns[*root].insert(&pattern);
2964  for (auto op : pattern.getGeneratedOps())
2965  parentOps[op].insert(*root);
2966 
2967  // Add this pattern to the worklist.
2968  patternWorklist.insert(&pattern);
2969  });
2970 
2971  // If there are any patterns that don't have a specific root kind, we can't
2972  // make direct assumptions about what operations will never be legalized.
2973  // Note: Technically we could, but it would require an analysis that may
2974  // recurse into itself. It would be better to perform this kind of filtering
2975  // at a higher level than here anyways.
2976  if (!anyOpLegalizerPatterns.empty()) {
2977  for (const Pattern *pattern : patternWorklist)
2978  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2979  return;
2980  }
2981 
2982  while (!patternWorklist.empty()) {
2983  auto *pattern = patternWorklist.pop_back_val();
2984 
2985  // Check to see if any of the generated operations are invalid.
2986  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2987  std::optional<LegalizationAction> action = target.getOpAction(op);
2988  return !legalizerPatterns.count(op) &&
2989  (!action || action == LegalizationAction::Illegal);
2990  }))
2991  continue;
2992 
2993  // Otherwise, if all of the generated operation are valid, this op is now
2994  // legal so add all of the child patterns to the worklist.
2995  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2996  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2997 
2998  // Add any invalid patterns of the parent operations to see if they have now
2999  // become legal.
3000  for (auto op : parentOps[*pattern->getRootKind()])
3001  patternWorklist.set_union(invalidPatterns[op]);
3002  }
3003 }
3004 
3005 void OperationLegalizer::computeLegalizationGraphBenefit(
3006  LegalizationPatterns &anyOpLegalizerPatterns,
3007  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
3008  // The smallest pattern depth, when legalizing an operation.
3009  DenseMap<OperationName, unsigned> minOpPatternDepth;
3010 
3011  // For each operation that is transitively legal, compute a cost for it.
3012  for (auto &opIt : legalizerPatterns)
3013  if (!minOpPatternDepth.count(opIt.first))
3014  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3015  legalizerPatterns);
3016 
3017  // Apply the cost model to the patterns that can match any operation. Those
3018  // with a specific operation type are already resolved when computing the op
3019  // legalization depth.
3020  if (!anyOpLegalizerPatterns.empty())
3021  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3022  legalizerPatterns);
3023 
3024  // Apply a cost model to the pattern applicator. We order patterns first by
3025  // depth then benefit. `legalizerPatterns` contains per-op patterns by
3026  // decreasing benefit.
3027  applicator.applyCostModel([&](const Pattern &pattern) {
3028  ArrayRef<const Pattern *> orderedPatternList;
3029  if (std::optional<OperationName> rootName = pattern.getRootKind())
3030  orderedPatternList = legalizerPatterns[*rootName];
3031  else
3032  orderedPatternList = anyOpLegalizerPatterns;
3033 
3034  // If the pattern is not found, then it was removed and cannot be matched.
3035  auto *it = llvm::find(orderedPatternList, &pattern);
3036  if (it == orderedPatternList.end())
3038 
3039  // Patterns found earlier in the list have higher benefit.
3040  return PatternBenefit(std::distance(it, orderedPatternList.end()));
3041  });
3042 }
3043 
3044 unsigned OperationLegalizer::computeOpLegalizationDepth(
3045  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
3046  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
3047  // Check for existing depth.
3048  auto depthIt = minOpPatternDepth.find(op);
3049  if (depthIt != minOpPatternDepth.end())
3050  return depthIt->second;
3051 
3052  // If a mapping for this operation does not exist, then this operation
3053  // is always legal. Return 0 as the depth for a directly legal operation.
3054  auto opPatternsIt = legalizerPatterns.find(op);
3055  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3056  return 0u;
3057 
3058  // Record this initial depth in case we encounter this op again when
3059  // recursively computing the depth.
3060  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3061 
3062  // Apply the cost model to the operation patterns, and update the minimum
3063  // depth.
3064  unsigned minDepth = applyCostModelToPatterns(
3065  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3066  minOpPatternDepth[op] = minDepth;
3067  return minDepth;
3068 }
3069 
3070 unsigned OperationLegalizer::applyCostModelToPatterns(
3071  LegalizationPatterns &patterns,
3072  DenseMap<OperationName, unsigned> &minOpPatternDepth,
3073  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
3074  unsigned minDepth = std::numeric_limits<unsigned>::max();
3075 
3076  // Compute the depth for each pattern within the set.
3077  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3078  patternsByDepth.reserve(patterns.size());
3079  for (const Pattern *pattern : patterns) {
3080  unsigned depth = 1;
3081  for (auto generatedOp : pattern->getGeneratedOps()) {
3082  unsigned generatedOpDepth = computeOpLegalizationDepth(
3083  generatedOp, minOpPatternDepth, legalizerPatterns);
3084  depth = std::max(depth, generatedOpDepth + 1);
3085  }
3086  patternsByDepth.emplace_back(pattern, depth);
3087 
3088  // Update the minimum depth of the pattern list.
3089  minDepth = std::min(minDepth, depth);
3090  }
3091 
3092  // If the operation only has one legalization pattern, there is no need to
3093  // sort them.
3094  if (patternsByDepth.size() == 1)
3095  return minDepth;
3096 
3097  // Sort the patterns by those likely to be the most beneficial.
3098  llvm::stable_sort(patternsByDepth,
3099  [](const std::pair<const Pattern *, unsigned> &lhs,
3100  const std::pair<const Pattern *, unsigned> &rhs) {
3101  // First sort by the smaller pattern legalization
3102  // depth.
3103  if (lhs.second != rhs.second)
3104  return lhs.second < rhs.second;
3105 
3106  // Then sort by the larger pattern benefit.
3107  auto lhsBenefit = lhs.first->getBenefit();
3108  auto rhsBenefit = rhs.first->getBenefit();
3109  return lhsBenefit > rhsBenefit;
3110  });
3111 
3112  // Update the legalization pattern to use the new sorted list.
3113  patterns.clear();
3114  for (auto &patternIt : patternsByDepth)
3115  patterns.push_back(patternIt.first);
3116  return minDepth;
3117 }
3118 
3119 //===----------------------------------------------------------------------===//
3120 // Reconcile Unrealized Casts
3121 //===----------------------------------------------------------------------===//
3122 
3123 /// Try to reconcile all given UnrealizedConversionCastOps and store the
3124 /// left-over ops in `remainingCastOps` (if provided). See documentation in
3125 /// DialectConversion.h for more details.
3126 /// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3127 /// algorithm may visit an operand (or user) which is a cast op, but will not
3128 /// try to reconcile it if not in the filtered set.
3129 template <typename RangeT>
3131  RangeT castOps,
3132  function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3133  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3134  // A worklist of cast ops to process.
3135  SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
3136 
3137  // Helper function that return the unrealized_conversion_cast op that
3138  // defines all inputs of the given op (in the same order). Return "nullptr"
3139  // if there is no such op.
3140  auto getInputCast =
3141  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3142  if (castOp.getInputs().empty())
3143  return {};
3144  auto inputCastOp =
3145  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3146  if (!inputCastOp)
3147  return {};
3148  if (inputCastOp.getOutputs() != castOp.getInputs())
3149  return {};
3150  return inputCastOp;
3151  };
3152 
3153  // Process ops in the worklist bottom-to-top.
3154  while (!worklist.empty()) {
3155  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3156 
3157  // Traverse the chain of input cast ops to see if an op with the same
3158  // input types can be found.
3159  UnrealizedConversionCastOp nextCast = castOp;
3160  while (nextCast) {
3161  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3162  if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
3163  return v.getDefiningOp() == castOp;
3164  })) {
3165  // Ran into a cycle.
3166  break;
3167  }
3168 
3169  // Found a cast where the input types match the output types of the
3170  // matched op. We can directly use those inputs.
3171  castOp.replaceAllUsesWith(nextCast.getInputs());
3172  break;
3173  }
3174  nextCast = getInputCast(nextCast);
3175  }
3176  }
3177 
3178  // A set of all alive cast ops. I.e., ops whose results are (transitively)
3179  // used by an op that is not a cast op.
3180  DenseSet<Operation *> liveOps;
3181 
3182  // Helper function that marks the given op and transitively reachable input
3183  // cast ops as alive.
3184  auto markOpLive = [&](Operation *rootOp) {
3185  SmallVector<Operation *> worklist;
3186  worklist.push_back(rootOp);
3187  while (!worklist.empty()) {
3188  Operation *op = worklist.pop_back_val();
3189  if (liveOps.insert(op).second) {
3190  // Successfully inserted: process reachable input cast ops.
3191  for (Value v : op->getOperands())
3192  if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3193  if (isCastOpOfInterestFn(castOp))
3194  worklist.push_back(castOp);
3195  }
3196  }
3197  };
3198 
3199  // Find all alive cast ops.
3200  for (UnrealizedConversionCastOp op : castOps) {
3201  // The op may have been marked live already as being an operand of another
3202  // live cast op.
3203  if (liveOps.contains(op.getOperation()))
3204  continue;
3205  // If any of the users is not a cast op, mark the current op (and its
3206  // input ops) as live.
3207  if (llvm::any_of(op->getUsers(), [&](Operation *user) {
3208  auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3209  return !castOp || !isCastOpOfInterestFn(castOp);
3210  }))
3211  markOpLive(op);
3212  }
3213 
3214  // Erase all dead cast ops.
3215  for (UnrealizedConversionCastOp op : castOps) {
3216  if (liveOps.contains(op)) {
3217  // Op is alive and was not erased. Add it to the remaining cast ops.
3218  if (remainingCastOps)
3219  remainingCastOps->push_back(op);
3220  continue;
3221  }
3222 
3223  // Op is dead. Erase it.
3224  op->dropAllUses();
3225  op->erase();
3226  }
3227 }
3228 
3231  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3232  // Set of all cast ops for faster lookups.
3234  for (UnrealizedConversionCastOp op : castOps)
3235  castOpSet.insert(op);
3236  reconcileUnrealizedCasts(castOpSet, remainingCastOps);
3237 }
3238 
3240  const DenseSet<UnrealizedConversionCastOp> &castOps,
3241  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3243  llvm::make_range(castOps.begin(), castOps.end()),
3244  [&](UnrealizedConversionCastOp castOp) {
3245  return castOps.contains(castOp);
3246  },
3247  remainingCastOps);
3248 }
3249 
3250 namespace mlir {
3253  &castOps,
3254  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3256  castOps.keys(),
3257  [&](UnrealizedConversionCastOp castOp) {
3258  return castOps.contains(castOp);
3259  },
3260  remainingCastOps);
3261 }
3262 } // namespace mlir
3263 
3264 //===----------------------------------------------------------------------===//
3265 // OperationConverter
3266 //===----------------------------------------------------------------------===//
3267 
namespace {
/// Selects how `OperationConverter::convert` reacts to the outcome of
/// legalizing an operation.
enum OpConversionMode {
  /// In this mode, the conversion will ignore failed conversions to allow
  /// illegal operations to co-exist in the IR. An operation that the target
  /// explicitly marks illegal still produces an error if it cannot be
  /// legalized; other failures are optionally recorded in
  /// `ConversionConfig::unlegalizedOps`.
  Partial,

  /// In this mode, all operations must be legal for the given target for the
  /// conversion to succeed.
  Full,

  /// In this mode, operations are analyzed for legality. No actual rewrites are
  /// applied to the operations on success. Successfully legalized operations
  /// are optionally recorded in `ConversionConfig::legalizableOps`.
  Analysis,
};
} // namespace
3283 
3284 namespace mlir {
3285 // This class converts operations to a given conversion target via a set of
3286 // rewrite patterns. The conversion behaves differently depending on the
3287 // conversion mode.
3289  explicit OperationConverter(MLIRContext *ctx, const ConversionTarget &target,
3291  const ConversionConfig &config,
3292  OpConversionMode mode)
3293  : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
3294  mode(mode) {}
3295 
3296  /// Converts the given operations to the conversion target.
3297  LogicalResult convertOperations(ArrayRef<Operation *> ops);
3298 
3299 private:
3300  /// Converts an operation with the given rewriter.
3301  LogicalResult convert(Operation *op);
3302 
3303  /// The rewriter to use when converting operations.
3304  ConversionPatternRewriter rewriter;
3305 
3306  /// The legalizer to use when converting operations.
3307  OperationLegalizer opLegalizer;
3308 
3309  /// The conversion mode to use when legalizing operations.
3310  OpConversionMode mode;
3311 };
3312 } // namespace mlir
3313 
3314 LogicalResult OperationConverter::convert(Operation *op) {
3315  const ConversionConfig &config = rewriter.getConfig();
3316 
3317  // Legalize the given operation.
3318  if (failed(opLegalizer.legalize(op))) {
3319  // Handle the case of a failed conversion for each of the different modes.
3320  // Full conversions expect all operations to be converted.
3321  if (mode == OpConversionMode::Full)
3322  return op->emitError()
3323  << "failed to legalize operation '" << op->getName() << "'";
3324  // Partial conversions allow conversions to fail iff the operation was not
3325  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
3326  // set, non-legalizable ops are added to that set.
3327  if (mode == OpConversionMode::Partial) {
3328  if (opLegalizer.isIllegal(op))
3329  return op->emitError()
3330  << "failed to legalize operation '" << op->getName()
3331  << "' that was explicitly marked illegal";
3332  if (config.unlegalizedOps)
3333  config.unlegalizedOps->insert(op);
3334  }
3335  } else if (mode == OpConversionMode::Analysis) {
3336  // Analysis conversions don't fail if any operations fail to legalize,
3337  // they are only interested in the operations that were successfully
3338  // legalized.
3339  if (config.legalizableOps)
3340  config.legalizableOps->insert(op);
3341  }
3342  return success();
3343 }
3344 
3345 static LogicalResult
3347  UnrealizedConversionCastOp op,
3348  const UnresolvedMaterializationInfo &info) {
3349  assert(!op.use_empty() &&
3350  "expected that dead materializations have already been DCE'd");
3351  Operation::operand_range inputOperands = op.getOperands();
3352 
3353  // Try to materialize the conversion.
3354  if (const TypeConverter *converter = info.getConverter()) {
3355  rewriter.setInsertionPoint(op);
3356  SmallVector<Value> newMaterialization;
3357  switch (info.getMaterializationKind()) {
3358  case MaterializationKind::Target:
3359  newMaterialization = converter->materializeTargetConversion(
3360  rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3361  info.getOriginalType());
3362  break;
3363  case MaterializationKind::Source:
3364  assert(op->getNumResults() == 1 && "expected single result");
3365  Value sourceMat = converter->materializeSourceConversion(
3366  rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3367  if (sourceMat)
3368  newMaterialization.push_back(sourceMat);
3369  break;
3370  }
3371  if (!newMaterialization.empty()) {
3372 #ifndef NDEBUG
3373  ValueRange newMaterializationRange(newMaterialization);
3374  assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
3375  "materialization callback produced value of incorrect type");
3376 #endif // NDEBUG
3377  rewriter.replaceOp(op, newMaterialization);
3378  return success();
3379  }
3380  }
3381 
3382  InFlightDiagnostic diag = op->emitError()
3383  << "failed to legalize unresolved materialization "
3384  "from ("
3385  << inputOperands.getTypes() << ") to ("
3386  << op.getResultTypes()
3387  << ") that remained live after conversion";
3388  diag.attachNote(op->getUsers().begin()->getLoc())
3389  << "see existing live user here: " << *op->getUsers().begin();
3390  return failure();
3391 }
3392 
3394  const ConversionTarget &target = opLegalizer.getTarget();
3395 
3396  // Compute the set of operations and blocks to convert.
3397  SmallVector<Operation *> toConvert;
3398  for (auto *op : ops) {
3400  [&](Operation *op) {
3401  toConvert.push_back(op);
3402  // Don't check this operation's children for conversion if the
3403  // operation is recursively legal.
3404  auto legalityInfo = target.isLegal(op);
3405  if (legalityInfo && legalityInfo->isRecursivelyLegal)
3406  return WalkResult::skip();
3407  return WalkResult::advance();
3408  });
3409  }
3410 
3411  // Convert each operation and discard rewrites on failure.
3412  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3413 
3414  for (auto *op : toConvert) {
3415  if (failed(convert(op))) {
3416  // Dialect conversion failed.
3417  if (rewriterImpl.config.allowPatternRollback) {
3418  // Rollback is allowed: restore the original IR.
3419  rewriterImpl.undoRewrites();
3420  } else {
3421  // Rollback is not allowed: apply all modifications that have been
3422  // performed so far.
3423  rewriterImpl.applyRewrites();
3424  }
3425  return failure();
3426  }
3427  }
3428 
3429  // After a successful conversion, apply rewrites.
3430  rewriterImpl.applyRewrites();
3431 
3432  // Reconcile all UnrealizedConversionCastOps that were inserted by the
3433  // dialect conversion frameworks. (Not the ones that were inserted by
3434  // patterns.)
3436  &materializations = rewriterImpl.unresolvedMaterializations;
3437  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
3438  reconcileUnrealizedCasts(materializations, &remainingCastOps);
3439 
3440  // Drop markers.
3441  for (UnrealizedConversionCastOp castOp : remainingCastOps)
3442  castOp->removeAttr(kPureTypeConversionMarker);
3443 
3444  // Try to legalize all unresolved materializations.
3445  if (rewriter.getConfig().buildMaterializations) {
3446  // Use a new rewriter, so the modifications are not tracked for rollback
3447  // purposes etc.
3448  IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3449  rewriter.getConfig().listener);
3450  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3451  auto it = materializations.find(castOp);
3452  assert(it != materializations.end() && "inconsistent state");
3453  if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3454  it->second)))
3455  return failure();
3456  }
3457  }
3458 
3459  return success();
3460 }
3461 
3462 //===----------------------------------------------------------------------===//
3463 // Type Conversion
3464 //===----------------------------------------------------------------------===//
3465 
3467  ArrayRef<Type> types) {
3468  assert(!types.empty() && "expected valid types");
3469  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3470  addInputs(types);
3471 }
3472 
3474  assert(!types.empty() &&
3475  "1->0 type remappings don't need to be added explicitly");
3476  argTypes.append(types.begin(), types.end());
3477 }
3478 
3479 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3480  unsigned newInputNo,
3481  unsigned newInputCount) {
3482  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3483  assert(newInputCount != 0 && "expected valid input count");
3484  remappedInputs[origInputNo] =
3485  InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3486 }
3487 
3489  unsigned origInputNo, ArrayRef<Value> replacements) {
3490  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3491  remappedInputs[origInputNo] = InputMapping{
3492  origInputNo, /*size=*/0,
3493  SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3494 }
3495 
3496 /// Internal implementation of the type conversion.
3497 /// This is used with either a Type or a Value as the first argument.
3498 /// - we can cache the context-free conversions until the last registered
3499 /// context-aware conversion.
3500 /// - we can't cache the result of type conversion happening after context-aware
3501 /// conversions, because the type converter may return different results for the
3502 /// same input type.
3503 LogicalResult
3504 TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3505  SmallVectorImpl<Type> &results) const {
3506  assert(typeOrValue && "expected non-null type");
3507  Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3508  : cast<Type>(typeOrValue);
3509  {
3510  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3511  std::defer_lock);
3513  cacheReadLock.lock();
3514  auto existingIt = cachedDirectConversions.find(t);
3515  if (existingIt != cachedDirectConversions.end()) {
3516  if (existingIt->second)
3517  results.push_back(existingIt->second);
3518  return success(existingIt->second != nullptr);
3519  }
3520  auto multiIt = cachedMultiConversions.find(t);
3521  if (multiIt != cachedMultiConversions.end()) {
3522  results.append(multiIt->second.begin(), multiIt->second.end());
3523  return success();
3524  }
3525  }
3526  // Walk the added converters in reverse order to apply the most recently
3527  // registered first.
3528  size_t currentCount = results.size();
3529 
3530  // We can cache the context-free conversions until the last registered
3531  // context-aware conversion. But only if we're processing a Value right now.
3532  auto isCacheable = [&](int index) {
3533  int numberOfConversionsUntilContextAware =
3534  conversions.size() - 1 - contextAwareTypeConversionsIndex;
3535  return index < numberOfConversionsUntilContextAware;
3536  };
3537 
3538  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3539  std::defer_lock);
3540 
3541  for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3542  const ConversionCallbackFn &converter = indexedConverter.value();
3543  std::optional<LogicalResult> result = converter(typeOrValue, results);
3544  if (!result) {
3545  assert(results.size() == currentCount &&
3546  "failed type conversion should not change results");
3547  continue;
3548  }
3549  if (!isCacheable(indexedConverter.index()))
3550  return success();
3552  cacheWriteLock.lock();
3553  if (!succeeded(*result)) {
3554  assert(results.size() == currentCount &&
3555  "failed type conversion should not change results");
3556  cachedDirectConversions.try_emplace(t, nullptr);
3557  return failure();
3558  }
3559  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3560  if (newTypes.size() == 1)
3561  cachedDirectConversions.try_emplace(t, newTypes.front());
3562  else
3563  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3564  return success();
3565  }
3566  return failure();
3567 }
3568 
3570  SmallVectorImpl<Type> &results) const {
3571  return convertTypeImpl(t, results);
3572 }
3573 
3575  SmallVectorImpl<Type> &results) const {
3576  return convertTypeImpl(v, results);
3577 }
3578 
3580  // Use the multi-type result version to convert the type.
3581  SmallVector<Type, 1> results;
3582  if (failed(convertType(t, results)))
3583  return nullptr;
3584 
3585  // Check to ensure that only one type was produced.
3586  return results.size() == 1 ? results.front() : nullptr;
3587 }
3588 
3590  // Use the multi-type result version to convert the type.
3591  SmallVector<Type, 1> results;
3592  if (failed(convertType(v, results)))
3593  return nullptr;
3594 
3595  // Check to ensure that only one type was produced.
3596  return results.size() == 1 ? results.front() : nullptr;
3597 }
3598 
3599 LogicalResult
3601  SmallVectorImpl<Type> &results) const {
3602  for (Type type : types)
3603  if (failed(convertType(type, results)))
3604  return failure();
3605  return success();
3606 }
3607 
3608 LogicalResult
3610  SmallVectorImpl<Type> &results) const {
3611  for (Value value : values)
3612  if (failed(convertType(value, results)))
3613  return failure();
3614  return success();
3615 }
3616 
3617 bool TypeConverter::isLegal(Type type) const {
3618  return convertType(type) == type;
3619 }
3620 
3621 bool TypeConverter::isLegal(Value value) const {
3622  return convertType(value) == value.getType();
3623 }
3624 
3626  return isLegal(op->getOperands()) && isLegal(op->getResults());
3627 }
3628 
3629 bool TypeConverter::isLegal(Region *region) const {
3630  return llvm::all_of(
3631  *region, [this](Block &block) { return isLegal(block.getArguments()); });
3632 }
3633 
3634 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3635  if (!isLegal(ty.getInputs()))
3636  return false;
3637  if (!isLegal(ty.getResults()))
3638  return false;
3639  return true;
3640 }
3641 
3642 LogicalResult
3644  SignatureConversion &result) const {
3645  // Try to convert the given input type.
3646  SmallVector<Type, 1> convertedTypes;
3647  if (failed(convertType(type, convertedTypes)))
3648  return failure();
3649 
3650  // If this argument is being dropped, there is nothing left to do.
3651  if (convertedTypes.empty())
3652  return success();
3653 
3654  // Otherwise, add the new inputs.
3655  result.addInputs(inputNo, convertedTypes);
3656  return success();
3657 }
3658 LogicalResult
3660  SignatureConversion &result,
3661  unsigned origInputOffset) const {
3662  for (unsigned i = 0, e = types.size(); i != e; ++i)
3663  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3664  return failure();
3665  return success();
3666 }
3667 LogicalResult
3669  SignatureConversion &result) const {
3670  // Try to convert the given input type.
3671  SmallVector<Type, 1> convertedTypes;
3672  if (failed(convertType(value, convertedTypes)))
3673  return failure();
3674 
3675  // If this argument is being dropped, there is nothing left to do.
3676  if (convertedTypes.empty())
3677  return success();
3678 
3679  // Otherwise, add the new inputs.
3680  result.addInputs(inputNo, convertedTypes);
3681  return success();
3682 }
3683 LogicalResult
3685  SignatureConversion &result,
3686  unsigned origInputOffset) const {
3687  for (unsigned i = 0, e = values.size(); i != e; ++i)
3688  if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3689  return failure();
3690  return success();
3691 }
3692 
3694  Location loc, Type resultType,
3695  ValueRange inputs) const {
3696  for (const SourceMaterializationCallbackFn &fn :
3697  llvm::reverse(sourceMaterializations))
3698  if (Value result = fn(builder, resultType, inputs, loc))
3699  return result;
3700  return nullptr;
3701 }
3702 
3704  Location loc, Type resultType,
3705  ValueRange inputs,
3706  Type originalType) const {
3708  builder, loc, TypeRange(resultType), inputs, originalType);
3709  if (result.empty())
3710  return nullptr;
3711  assert(result.size() == 1 && "expected single result");
3712  return result.front();
3713 }
3714 
3716  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3717  Type originalType) const {
3718  for (const TargetMaterializationCallbackFn &fn :
3719  llvm::reverse(targetMaterializations)) {
3720  SmallVector<Value> result =
3721  fn(builder, resultTypes, inputs, loc, originalType);
3722  if (result.empty())
3723  continue;
3724  assert(TypeRange(ValueRange(result)) == resultTypes &&
3725  "callback produced incorrect number of values or values with "
3726  "incorrect types");
3727  return result;
3728  }
3729  return {};
3730 }
3731 
3732 std::optional<TypeConverter::SignatureConversion>
3734  SignatureConversion conversion(block->getNumArguments());
3735  if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3736  return std::nullopt;
3737  return conversion;
3738 }
3739 
3740 //===----------------------------------------------------------------------===//
3741 // Type attribute conversion
3742 //===----------------------------------------------------------------------===//
3745  return AttributeConversionResult(attr, resultTag);
3746 }
3747 
3750  return AttributeConversionResult(nullptr, naTag);
3751 }
3752 
3755  return AttributeConversionResult(nullptr, abortTag);
3756 }
3757 
3759  return impl.getInt() == resultTag;
3760 }
3761 
3763  return impl.getInt() == naTag;
3764 }
3765 
3767  return impl.getInt() == abortTag;
3768 }
3769 
3771  assert(hasResult() && "Cannot get result from N/A or abort");
3772  return impl.getPointer();
3773 }
3774 
3775 std::optional<Attribute>
3777  for (const TypeAttributeConversionCallbackFn &fn :
3778  llvm::reverse(typeAttributeConversions)) {
3779  AttributeConversionResult res = fn(type, attr);
3780  if (res.hasResult())
3781  return res.getResult();
3782  if (res.isAbort())
3783  return std::nullopt;
3784  }
3785  return std::nullopt;
3786 }
3787 
3788 //===----------------------------------------------------------------------===//
3789 // FunctionOpInterfaceSignatureConversion
3790 //===----------------------------------------------------------------------===//
3791 
3792 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3793  const TypeConverter &typeConverter,
3794  ConversionPatternRewriter &rewriter) {
3795  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3796  if (!type)
3797  return failure();
3798 
3799  // Convert the original function types.
3800  TypeConverter::SignatureConversion result(type.getNumInputs());
3801  SmallVector<Type, 1> newResults;
3802  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3803  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3804  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3805  typeConverter, &result)))
3806  return failure();
3807 
3808  // Update the function signature in-place.
3809  auto newType = FunctionType::get(rewriter.getContext(),
3810  result.getConvertedTypes(), newResults);
3811 
3812  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3813 
3814  return success();
3815 }
3816 
3817 /// Create a default conversion pattern that rewrites the type signature of a
3818 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3819 /// represent their type.
3820 namespace {
3821 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3822  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3823  MLIRContext *ctx,
3824  const TypeConverter &converter,
3825  PatternBenefit benefit)
3826  : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3827 
3828  LogicalResult
3829  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3830  ConversionPatternRewriter &rewriter) const override {
3831  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3832  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3833  }
3834 };
3835 
3836 struct AnyFunctionOpInterfaceSignatureConversion
3837  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3839 
3840  LogicalResult
3841  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3842  ConversionPatternRewriter &rewriter) const override {
3843  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3844  }
3845 };
3846 } // namespace
3847 
3848 FailureOr<Operation *>
3850  const TypeConverter &converter,
3851  ConversionPatternRewriter &rewriter) {
3852  assert(op && "Invalid op");
3853  Location loc = op->getLoc();
3854  if (converter.isLegal(op))
3855  return rewriter.notifyMatchFailure(loc, "op already legal");
3856 
3857  OperationState newOp(loc, op->getName());
3858  newOp.addOperands(operands);
3859 
3860  SmallVector<Type> newResultTypes;
3861  if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3862  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3863 
3864  newOp.addTypes(newResultTypes);
3865  newOp.addAttributes(op->getAttrs());
3866  return rewriter.create(newOp);
3867 }
3868 
3870  StringRef functionLikeOpName, RewritePatternSet &patterns,
3871  const TypeConverter &converter, PatternBenefit benefit) {
3872  patterns.add<FunctionOpInterfaceSignatureConversion>(
3873  functionLikeOpName, patterns.getContext(), converter, benefit);
3874 }
3875 
3877  RewritePatternSet &patterns, const TypeConverter &converter,
3878  PatternBenefit benefit) {
3879  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3880  converter, patterns.getContext(), benefit);
3881 }
3882 
3883 //===----------------------------------------------------------------------===//
3884 // ConversionTarget
3885 //===----------------------------------------------------------------------===//
3886 
3888  LegalizationAction action) {
3889  legalOperations[op].action = action;
3890 }
3891 
3893  LegalizationAction action) {
3894  for (StringRef dialect : dialectNames)
3895  legalDialects[dialect] = action;
3896 }
3897 
3899  -> std::optional<LegalizationAction> {
3900  std::optional<LegalizationInfo> info = getOpInfo(op);
3901  return info ? info->action : std::optional<LegalizationAction>();
3902 }
3903 
3905  -> std::optional<LegalOpDetails> {
3906  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3907  if (!info)
3908  return std::nullopt;
3909 
3910  // Returns true if this operation instance is known to be legal.
3911  auto isOpLegal = [&] {
3912  // Handle dynamic legality either with the provided legality function.
3913  if (info->action == LegalizationAction::Dynamic) {
3914  std::optional<bool> result = info->legalityFn(op);
3915  if (result)
3916  return *result;
3917  }
3918 
3919  // Otherwise, the operation is only legal if it was marked 'Legal'.
3920  return info->action == LegalizationAction::Legal;
3921  };
3922  if (!isOpLegal())
3923  return std::nullopt;
3924 
3925  // This operation is legal, compute any additional legality information.
3926  LegalOpDetails legalityDetails;
3927  if (info->isRecursivelyLegal) {
3928  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3929  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3930  legalityDetails.isRecursivelyLegal =
3931  legalityFnIt->second(op).value_or(true);
3932  } else {
3933  legalityDetails.isRecursivelyLegal = true;
3934  }
3935  }
3936  return legalityDetails;
3937 }
3938 
3940  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3941  if (!info)
3942  return false;
3943 
3944  if (info->action == LegalizationAction::Dynamic) {
3945  std::optional<bool> result = info->legalityFn(op);
3946  if (!result)
3947  return false;
3948 
3949  return !(*result);
3950  }
3951 
3952  return info->action == LegalizationAction::Illegal;
3953 }
3954 
3958  if (!oldCallback)
3959  return newCallback;
3960 
3961  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3962  Operation *op) -> std::optional<bool> {
3963  if (std::optional<bool> result = newCl(op))
3964  return *result;
3965 
3966  return oldCl(op);
3967  };
3968  return chain;
3969 }
3970 
3971 void ConversionTarget::setLegalityCallback(
3972  OperationName name, const DynamicLegalityCallbackFn &callback) {
3973  assert(callback && "expected valid legality callback");
3974  auto *infoIt = legalOperations.find(name);
3975  assert(infoIt != legalOperations.end() &&
3976  infoIt->second.action == LegalizationAction::Dynamic &&
3977  "expected operation to already be marked as dynamically legal");
3978  infoIt->second.legalityFn =
3979  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3980 }
3981 
3983  OperationName name, const DynamicLegalityCallbackFn &callback) {
3984  auto *infoIt = legalOperations.find(name);
3985  assert(infoIt != legalOperations.end() &&
3986  infoIt->second.action != LegalizationAction::Illegal &&
3987  "expected operation to already be marked as legal");
3988  infoIt->second.isRecursivelyLegal = true;
3989  if (callback)
3990  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3991  std::move(opRecursiveLegalityFns[name]), callback);
3992  else
3993  opRecursiveLegalityFns.erase(name);
3994 }
3995 
3996 void ConversionTarget::setLegalityCallback(
3997  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3998  assert(callback && "expected valid legality callback");
3999  for (StringRef dialect : dialects)
4000  dialectLegalityFns[dialect] = composeLegalityCallbacks(
4001  std::move(dialectLegalityFns[dialect]), callback);
4002 }
4003 
4004 void ConversionTarget::setLegalityCallback(
4005  const DynamicLegalityCallbackFn &callback) {
4006  assert(callback && "expected valid legality callback");
4007  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
4008 }
4009 
4010 auto ConversionTarget::getOpInfo(OperationName op) const
4011  -> std::optional<LegalizationInfo> {
4012  // Check for info for this specific operation.
4013  const auto *it = legalOperations.find(op);
4014  if (it != legalOperations.end())
4015  return it->second;
4016  // Check for info for the parent dialect.
4017  auto dialectIt = legalDialects.find(op.getDialectNamespace());
4018  if (dialectIt != legalDialects.end()) {
4019  DynamicLegalityCallbackFn callback;
4020  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4021  if (dialectFn != dialectLegalityFns.end())
4022  callback = dialectFn->second;
4023  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
4024  callback};
4025  }
4026  // Otherwise, check if we mark unknown operations as dynamic.
4027  if (unknownLegalityFn)
4028  return LegalizationInfo{LegalizationAction::Dynamic,
4029  /*isRecursivelyLegal=*/false, unknownLegalityFn};
4030  return std::nullopt;
4031 }
4032 
4033 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4034 //===----------------------------------------------------------------------===//
4035 // PDL Configuration
4036 //===----------------------------------------------------------------------===//
4037 
4039  auto &rewriterImpl =
4040  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4041  rewriterImpl.currentTypeConverter = getTypeConverter();
4042 }
4043 
4045  auto &rewriterImpl =
4046  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4047  rewriterImpl.currentTypeConverter = nullptr;
4048 }
4049 
4050 /// Remap the given value using the rewriter and the type converter in the
4051 /// provided config.
4052 static FailureOr<SmallVector<Value>>
4054  SmallVector<Value> mappedValues;
4055  if (failed(rewriter.getRemappedValues(values, mappedValues)))
4056  return failure();
4057  return std::move(mappedValues);
4058 }
4059 
4061  patterns.getPDLPatterns().registerRewriteFunction(
4062  "convertValue",
4063  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4064  auto results = pdllConvertValues(
4065  static_cast<ConversionPatternRewriter &>(rewriter), value);
4066  if (failed(results))
4067  return failure();
4068  return results->front();
4069  });
4070  patterns.getPDLPatterns().registerRewriteFunction(
4071  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
4072  return pdllConvertValues(
4073  static_cast<ConversionPatternRewriter &>(rewriter), values);
4074  });
4075  patterns.getPDLPatterns().registerRewriteFunction(
4076  "convertType",
4077  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4078  auto &rewriterImpl =
4079  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4080  if (const TypeConverter *converter =
4081  rewriterImpl.currentTypeConverter) {
4082  if (Type newType = converter->convertType(type))
4083  return newType;
4084  return failure();
4085  }
4086  return type;
4087  });
4088  patterns.getPDLPatterns().registerRewriteFunction(
4089  "convertTypes",
4090  [](PatternRewriter &rewriter,
4091  TypeRange types) -> FailureOr<SmallVector<Type>> {
4092  auto &rewriterImpl =
4093  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4094  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
4095  if (!converter)
4096  return SmallVector<Type>(types);
4097 
4098  SmallVector<Type> remappedTypes;
4099  if (failed(converter->convertTypes(types, remappedTypes)))
4100  return failure();
4101  return std::move(remappedTypes);
4102  });
4103 }
4104 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4105 
4106 //===----------------------------------------------------------------------===//
4107 // Op Conversion Entry Points
4108 //===----------------------------------------------------------------------===//
4109 
4110 /// This is the type of Action that is dispatched when a conversion is applied.
4112  : public tracing::ActionImpl<ApplyConversionAction> {
4113 public:
4116  static constexpr StringLiteral tag = "apply-conversion";
4117  static constexpr StringLiteral desc =
4118  "Encapsulate the application of a dialect conversion";
4119 
4120  void print(raw_ostream &os) const override { os << tag; }
4121 };
4122 
4123 static LogicalResult applyConversion(ArrayRef<Operation *> ops,
4124  const ConversionTarget &target,
4127  OpConversionMode mode) {
4128  if (ops.empty())
4129  return success();
4130  MLIRContext *ctx = ops.front()->getContext();
4131  LogicalResult status = success();
4132  SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4134  [&] {
4135  OperationConverter opConverter(ops.front()->getContext(), target,
4136  patterns, config, mode);
4137  status = opConverter.convertOperations(ops);
4138  },
4139  irUnits);
4140  return status;
4141 }
4142 
4143 //===----------------------------------------------------------------------===//
4144 // Partial Conversion
4145 //===----------------------------------------------------------------------===//
4146 
4148  ArrayRef<Operation *> ops, const ConversionTarget &target,
4150  return applyConversion(ops, target, patterns, config,
4151  OpConversionMode::Partial);
4152 }
4153 LogicalResult
4157  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4158 }
4159 
4160 //===----------------------------------------------------------------------===//
4161 // Full Conversion
4162 //===----------------------------------------------------------------------===//
4163 
4165  const ConversionTarget &target,
4168  return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4169 }
4171  const ConversionTarget &target,
4174  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4175 }
4176 
4177 //===----------------------------------------------------------------------===//
4178 // Analysis Conversion
4179 //===----------------------------------------------------------------------===//
4180 
4181 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
4182 /// op is a top-level module op (which is expected to be isolated from above),
4183 /// return that op.
4185  // Check if there is a top-level operation within `ops`. If so, return that
4186  // op.
4187  for (Operation *op : ops) {
4188  if (!op->getParentOp()) {
4189 #ifndef NDEBUG
4190  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4191  "expected top-level op to be isolated from above");
4192  for (Operation *other : ops)
4193  assert(op->isAncestor(other) &&
4194  "expected ops to have a common ancestor");
4195 #endif // NDEBUG
4196  return op;
4197  }
4198  }
4199 
4200  // No top-level op. Find a common ancestor.
4201  Operation *commonAncestor =
4202  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4203  for (Operation *op : ops.drop_front()) {
4204  while (!commonAncestor->isProperAncestor(op)) {
4205  commonAncestor =
4207  assert(commonAncestor &&
4208  "expected to find a common isolated from above ancestor");
4209  }
4210  }
4211 
4212  return commonAncestor;
4213 }
4214 
4218 #ifndef NDEBUG
4219  if (config.legalizableOps)
4220  assert(config.legalizableOps->empty() && "expected empty set");
4221 #endif // NDEBUG
4222 
4223  // Clone closted common ancestor that is isolated from above.
4224  Operation *commonAncestor = findCommonAncestor(ops);
4225  IRMapping mapping;
4226  Operation *clonedAncestor = commonAncestor->clone(mapping);
4227  // Compute inverse IR mapping.
4228  DenseMap<Operation *, Operation *> inverseOperationMap;
4229  for (auto &it : mapping.getOperationMap())
4230  inverseOperationMap[it.second] = it.first;
4231 
4232  // Convert the cloned operations. The original IR will remain unchanged.
4233  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4234  ops, [&](Operation *op) { return mapping.lookup(op); });
4235  LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4236  OpConversionMode::Analysis);
4237 
4238  // Remap `legalizableOps`, so that they point to the original ops and not the
4239  // cloned ops.
4240  if (config.legalizableOps) {
4241  DenseSet<Operation *> originalLegalizableOps;
4242  for (Operation *op : *config.legalizableOps)
4243  originalLegalizableOps.insert(inverseOperationMap[op]);
4244  *config.legalizableOps = std::move(originalLegalizableOps);
4245  }
4246 
4247  // Erase the cloned IR.
4248  clonedAncestor->erase();
4249  return status;
4250 }
4251 
4252 LogicalResult
4256  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4257 }
static void setInsertionPointAfter(OpBuilder &b, Value value)
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value >> &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
#define DEBUG_TYPE
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps, const SetVector< Block * > &insertedBlocks)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1247::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:94
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:308
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
UnitAttr getUnitAttr()
Definition: Builders.cpp:98
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:262
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:25
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
const ConversionConfig & getConfig() const
Return the configuration of the current dialect conversion.
void replaceAllUsesWith(Value from, ValueRange to)
Replace all the uses of from with to.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={}) override
PatternRewriter hook for inlining the ops of a block into another block.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Base class for the conversion patterns.
FailureOr< SmallVector< Value > > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true is operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:774
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:32
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:274
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:327
Block::iterator getPoint() const
Definition: Builders.h:340
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:337
Block * getBlock() const
Definition: Builders.h:339
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:320
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:473
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:616
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:457
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:421
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getTypes() const
Definition: ValueRange.cpp:28
A unique fingerprint for a specific operation, and all of it's internal operations (if includeNested ...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:719
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:607
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
result_range getResults()
Definition: Operation.h:415
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:896
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:140
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
Definition: PatternMatch.h:632
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:622
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacements values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:46
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult skip()
Definition: WalkResult.h:48
static WalkResult advance()
Definition: WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
CRTP Implementation of an action.
Definition: Action.h:76
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
@ AfterPatterns
Only attempt to fold not legal operations after applying patterns.
@ BeforePatterns
Only attempt to fold not legal operations before applying patterns.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
void reconcileUnrealizedCasts(const DenseSet< UnrealizedConversionCastOp > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
bool attachDebugMaterializationKind
If set to "true", the materialization kind ("source" or "target") will be attached to "builtin....
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:308
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:298
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:379
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:403
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:387
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:376
This struct represents a range of new types or a range of values that remaps an existing signature in...
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config)
void replaceAllUsesWith(Value from, ValueRange to, const TypeConverter *converter)
Replace the uses of the given value with the given values.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Look up the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Look up the most recently mapped values with the desired types in the mapping, taking into account only replacements.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
SetVector< Block * > patternInsertedBlocks
A set of blocks that were inserted (newly-created blocks or moved blocks) by the current pattern.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void replaceOp(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the results of the given operation with the given values and erase the operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.