MLIR  22.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Iterators.h"
17 #include "mlir/IR/Operation.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/Support/SaveAndRestore.h"
26 #include "llvm/Support/ScopedPrinter.h"
27 #include <optional>
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 #define DEBUG_TYPE "dialect-conversion"
33 
34 /// A utility function to log a successful result for the given reason.
35 template <typename... Args>
36 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
37  LLVM_DEBUG({
38  os.unindent();
39  os.startLine() << "} -> SUCCESS";
40  if (!fmt.empty())
41  os.getOStream() << " : "
42  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
43  os.getOStream() << "\n";
44  });
45 }
46 
47 /// A utility function to log a failure result for the given reason.
48 template <typename... Args>
49 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
50  LLVM_DEBUG({
51  os.unindent();
52  os.startLine() << "} -> FAILURE : "
53  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
54  << "\n";
55  });
56 }
57 
58 /// Helper function that computes an insertion point where the given value is
59 /// defined and can be used without a dominance violation.
61  Block *insertBlock = value.getParentBlock();
62  Block::iterator insertPt = insertBlock->begin();
63  if (OpResult inputRes = dyn_cast<OpResult>(value))
64  insertPt = ++inputRes.getOwner()->getIterator();
65  return OpBuilder::InsertPoint(insertBlock, insertPt);
66 }
67 
68 /// Helper function that computes an insertion point where the given values are
69 /// defined and can be used without a dominance violation.
71  assert(!vals.empty() && "expected at least one value");
72  DominanceInfo domInfo;
73  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
74  for (Value v : vals.drop_front()) {
75  // Choose the "later" insertion point.
77  if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
78  nextPt.getPoint())) {
79  // pt is before nextPt => choose nextPt.
80  pt = nextPt;
81  } else {
82 #ifndef NDEBUG
83  // nextPt should be before pt => choose pt.
84  // If pt, nextPt are no dominance relationship, then there is no valid
85  // insertion point at which all given values are defined.
86  bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
87  pt.getBlock(), pt.getPoint());
88  assert(dom && "unable to find valid insertion point");
89 #endif // NDEBUG
90  }
91  }
92  return pt;
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // ConversionValueMapping
97 //===----------------------------------------------------------------------===//
98 
99 /// A vector of SSA values, optimized for the most common case of a single
100 /// value.
101 using ValueVector = SmallVector<Value, 1>;
102 
103 namespace {
104 
105 /// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
106 struct ValueVectorMapInfo {
107  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
108  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
109  static ::llvm::hash_code getHashValue(const ValueVector &val) {
110  return ::llvm::hash_combine_range(val);
111  }
112  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
113  return LHS == RHS;
114  }
115 };
116 
117 /// This class wraps a IRMapping to provide recursive lookup
118 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
119 struct ConversionValueMapping {
120  /// Return "true" if an SSA value is mapped to the given value. May return
121  /// false positives.
122  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
123 
124  /// Lookup a value in the mapping.
125  ValueVector lookup(const ValueVector &from) const;
126 
127  template <typename T>
128  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
129 
130  /// Map a value vector to the one provided.
131  template <typename OldVal, typename NewVal>
132  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
133  map(OldVal &&oldVal, NewVal &&newVal) {
134  LLVM_DEBUG({
135  ValueVector next(newVal);
136  while (true) {
137  assert(next != oldVal && "inserting cyclic mapping");
138  auto it = mapping.find(next);
139  if (it == mapping.end())
140  break;
141  next = it->second;
142  }
143  });
144  mappedTo.insert_range(newVal);
145 
146  mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
147  }
148 
149  /// Map a value vector or single value to the one provided.
150  template <typename OldVal, typename NewVal>
151  std::enable_if_t<!IsValueVector<OldVal>::value ||
152  !IsValueVector<NewVal>::value>
153  map(OldVal &&oldVal, NewVal &&newVal) {
154  if constexpr (IsValueVector<OldVal>{}) {
155  map(std::forward<OldVal>(oldVal), ValueVector{newVal});
156  } else if constexpr (IsValueVector<NewVal>{}) {
157  map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
158  } else {
159  map(ValueVector{oldVal}, ValueVector{newVal});
160  }
161  }
162 
163  void map(Value oldVal, SmallVector<Value> &&newVal) {
164  map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
165  }
166 
167  /// Drop the last mapping for the given values.
168  void erase(const ValueVector &value) { mapping.erase(value); }
169 
170 private:
171  /// Current value mappings.
173 
174  /// All SSA values that are mapped to. May contain false positives.
175  DenseSet<Value> mappedTo;
176 };
177 } // namespace
178 
179 /// Marker attribute for pure type conversions. I.e., mappings whose only
180 /// purpose is to resolve a type mismatch. (In contrast, mappings that point to
181 /// the replacement values of a "replaceOp" call, etc., are not pure type
182 /// conversions.)
183 static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
184 
185 /// Return the operation that defines all values in the vector. Return nullptr
186 /// if the values are not defined by the same operation.
187 static Operation *getCommonDefiningOp(const ValueVector &values) {
188  assert(!values.empty() && "expected non-empty value vector");
189  Operation *op = values.front().getDefiningOp();
190  for (Value v : llvm::drop_begin(values)) {
191  if (v.getDefiningOp() != op)
192  return nullptr;
193  }
194  return op;
195 }
196 
197 /// A vector of values is a pure type conversion if all values are defined by
198 /// the same operation and the operation has the `kPureTypeConversionMarker`
199 /// attribute.
200 static bool isPureTypeConversion(const ValueVector &values) {
201  assert(!values.empty() && "expected non-empty value vector");
202  Operation *op = getCommonDefiningOp(values);
203  return op && op->hasAttr(kPureTypeConversionMarker);
204 }
205 
206 ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
207  auto it = mapping.find(from);
208  if (it == mapping.end()) {
209  // No mapping found: The lookup stops here.
210  return {};
211  }
212  return it->second;
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // Rewriter and Translation State
217 //===----------------------------------------------------------------------===//
218 namespace {
/// A snapshot of the conversion rewriter's bookkeeping counters. Saving one
/// and later restoring it is how a set of rewrites gets undone.
struct RewriterState {
  RewriterState(unsigned rewriteCount, unsigned ignoredOpCount,
                unsigned replacedOpCount)
      : numRewrites(rewriteCount), numIgnoredOperations(ignoredOpCount),
        numReplacedOps(replacedOpCount) {}

  /// Number of rewrites that had been performed at snapshot time.
  unsigned numRewrites;

  /// Number of ignored operations at snapshot time.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure at snapshot time.
  unsigned numReplacedOps;
};
240 
241 static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
242 
243 /// Notify the listener that the given block and its contents are being erased.
244 static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
245  for (Operation &op : b)
246  notifyIRErased(listener, op);
247  listener->notifyBlockErased(&b);
248 }
249 
250 /// Notify the listener that the given operation and its contents are being
251 /// erased.
252 static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
253  for (Region &r : op.getRegions()) {
254  for (Block &b : r) {
255  notifyIRErased(listener, b);
256  }
257  }
258  listener->notifyOperationErased(&op);
259 }
260 
261 /// An IR rewrite that can be committed (upon success) or rolled back (upon
262 /// failure).
263 ///
264 /// The dialect conversion keeps track of IR modifications (requested by the
265 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
266 /// are directly applied to the IR as the rewriter API is used, some are applied
267 /// partially, and some are delayed until the `IRRewrite` objects are committed.
268 class IRRewrite {
269 public:
270  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
271  /// Enum values are ordered, so that they can be used in `classof`: first all
272  /// block rewrites, then all operation rewrites.
273  enum class Kind {
274  // Block rewrites
275  CreateBlock,
276  EraseBlock,
277  InlineBlock,
278  MoveBlock,
279  BlockTypeConversion,
280  // Operation rewrites
281  MoveOperation,
282  ModifyOperation,
283  ReplaceOperation,
284  CreateOperation,
285  UnresolvedMaterialization,
286  // Value rewrites
287  ReplaceValue
288  };
289 
290  virtual ~IRRewrite() = default;
291 
292  /// Roll back the rewrite. Operations may be erased during rollback.
293  virtual void rollback() = 0;
294 
295  /// Commit the rewrite. At this point, it is certain that the dialect
296  /// conversion will succeed. All IR modifications, except for operation/block
297  /// erasure, must be performed through the given rewriter.
298  ///
299  /// Instead of erasing operations/blocks, they should merely be unlinked
300  /// commit phase and finally be erased during the cleanup phase. This is
301  /// because internal dialect conversion state (such as `mapping`) may still
302  /// be using them.
303  ///
304  /// Any IR modification that was already performed before the commit phase
305  /// (e.g., insertion of an op) must be communicated to the listener that may
306  /// be attached to the given rewriter.
307  virtual void commit(RewriterBase &rewriter) {}
308 
309  /// Cleanup operations/blocks. Cleanup is called after commit.
310  virtual void cleanup(RewriterBase &rewriter) {}
311 
312  Kind getKind() const { return kind; }
313 
314  static bool classof(const IRRewrite *rewrite) { return true; }
315 
316 protected:
317  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
318  : kind(kind), rewriterImpl(rewriterImpl) {}
319 
320  const ConversionConfig &getConfig() const;
321 
322  const Kind kind;
323  ConversionPatternRewriterImpl &rewriterImpl;
324 };
325 
326 /// A block rewrite.
327 class BlockRewrite : public IRRewrite {
328 public:
329  /// Return the block that this rewrite operates on.
330  Block *getBlock() const { return block; }
331 
332  static bool classof(const IRRewrite *rewrite) {
333  return rewrite->getKind() >= Kind::CreateBlock &&
334  rewrite->getKind() <= Kind::BlockTypeConversion;
335  }
336 
337 protected:
338  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
339  Block *block)
340  : IRRewrite(kind, rewriterImpl), block(block) {}
341 
342  // The block that this rewrite operates on.
343  Block *block;
344 };
345 
346 /// A value rewrite.
347 class ValueRewrite : public IRRewrite {
348 public:
349  /// Return the value that this rewrite operates on.
350  Value getValue() const { return value; }
351 
352  static bool classof(const IRRewrite *rewrite) {
353  return rewrite->getKind() == Kind::ReplaceValue;
354  }
355 
356 protected:
357  ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
358  Value value)
359  : IRRewrite(kind, rewriterImpl), value(value) {}
360 
361  // The value that this rewrite operates on.
362  Value value;
363 };
364 
365 /// Creation of a block. Block creations are immediately reflected in the IR.
366 /// There is no extra work to commit the rewrite. During rollback, the newly
367 /// created block is erased.
368 class CreateBlockRewrite : public BlockRewrite {
369 public:
370  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
371  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
372 
373  static bool classof(const IRRewrite *rewrite) {
374  return rewrite->getKind() == Kind::CreateBlock;
375  }
376 
377  void commit(RewriterBase &rewriter) override {
378  // The block was already created and inserted. Just inform the listener.
379  if (auto *listener = rewriter.getListener())
380  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
381  }
382 
383  void rollback() override {
384  // Unlink all of the operations within this block, they will be deleted
385  // separately.
386  auto &blockOps = block->getOperations();
387  while (!blockOps.empty())
388  blockOps.remove(blockOps.begin());
389  block->dropAllUses();
390  if (block->getParent())
391  block->erase();
392  else
393  delete block;
394  }
395 };
396 
397 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
398 /// blocks are immediately unlinked, but only erased during cleanup. This makes
399 /// it easier to rollback a block erasure: the block is simply inserted into its
400 /// original location.
401 class EraseBlockRewrite : public BlockRewrite {
402 public:
403  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
404  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
405  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
406 
407  static bool classof(const IRRewrite *rewrite) {
408  return rewrite->getKind() == Kind::EraseBlock;
409  }
410 
411  ~EraseBlockRewrite() override {
412  assert(!block &&
413  "rewrite was neither rolled back nor committed/cleaned up");
414  }
415 
416  void rollback() override {
417  // The block (owned by this rewrite) was not actually erased yet. It was
418  // just unlinked. Put it back into its original position.
419  assert(block && "expected block");
420  auto &blockList = region->getBlocks();
421  Region::iterator before = insertBeforeBlock
422  ? Region::iterator(insertBeforeBlock)
423  : blockList.end();
424  blockList.insert(before, block);
425  block = nullptr;
426  }
427 
428  void commit(RewriterBase &rewriter) override {
429  assert(block && "expected block");
430 
431  // Notify the listener that the block and its contents are being erased.
432  if (auto *listener =
433  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
434  notifyIRErased(listener, *block);
435  }
436 
437  void cleanup(RewriterBase &rewriter) override {
438  // Erase the contents of the block.
439  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
440  rewriter.eraseOp(&op);
441  assert(block->empty() && "expected empty block");
442 
443  // Erase the block.
444  block->dropAllDefinedValueUses();
445  delete block;
446  block = nullptr;
447  }
448 
449 private:
450  // The region in which this block was previously contained.
451  Region *region;
452 
453  // The original successor of this block before it was unlinked. "nullptr" if
454  // this block was the only block in the region.
455  Block *insertBeforeBlock;
456 };
457 
458 /// Inlining of a block. This rewrite is immediately reflected in the IR.
459 /// Note: This rewrite represents only the inlining of the operations. The
460 /// erasure of the inlined block is a separate rewrite.
461 class InlineBlockRewrite : public BlockRewrite {
462 public:
463  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
464  Block *sourceBlock, Block::iterator before)
465  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
466  sourceBlock(sourceBlock),
467  firstInlinedInst(sourceBlock->empty() ? nullptr
468  : &sourceBlock->front()),
469  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
470  // If a listener is attached to the dialect conversion, ops must be moved
471  // one-by-one. When they are moved in bulk, notifications cannot be sent
472  // because the ops that used to be in the source block at the time of the
473  // inlining (before the "commit" phase) are unknown at the time when
474  // notifications are sent (which is during the "commit" phase).
475  assert(!getConfig().listener &&
476  "InlineBlockRewrite not supported if listener is attached");
477  }
478 
479  static bool classof(const IRRewrite *rewrite) {
480  return rewrite->getKind() == Kind::InlineBlock;
481  }
482 
483  void rollback() override {
484  // Put the operations from the destination block (owned by the rewrite)
485  // back into the source block.
486  if (firstInlinedInst) {
487  assert(lastInlinedInst && "expected operation");
488  sourceBlock->getOperations().splice(sourceBlock->begin(),
489  block->getOperations(),
490  Block::iterator(firstInlinedInst),
491  ++Block::iterator(lastInlinedInst));
492  }
493  }
494 
495 private:
496  // The block that originally contained the operations.
497  Block *sourceBlock;
498 
499  // The first inlined operation.
500  Operation *firstInlinedInst;
501 
502  // The last inlined operation.
503  Operation *lastInlinedInst;
504 };
505 
506 /// Moving of a block. This rewrite is immediately reflected in the IR.
507 class MoveBlockRewrite : public BlockRewrite {
508 public:
509  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
510  Region *previousRegion, Region::iterator previousIt)
511  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
512  region(previousRegion),
513  insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
514  : &*previousIt) {}
515 
516  static bool classof(const IRRewrite *rewrite) {
517  return rewrite->getKind() == Kind::MoveBlock;
518  }
519 
520  void commit(RewriterBase &rewriter) override {
521  // The block was already moved. Just inform the listener.
522  if (auto *listener = rewriter.getListener()) {
523  // Note: `previousIt` cannot be passed because this is a delayed
524  // notification and iterators into past IR state cannot be represented.
525  listener->notifyBlockInserted(block, /*previous=*/region,
526  /*previousIt=*/{});
527  }
528  }
529 
530  void rollback() override {
531  // Move the block back to its original position.
532  Region::iterator before =
533  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
534  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
535  }
536 
537 private:
538  // The region in which this block was previously contained.
539  Region *region;
540 
541  // The original successor of this block before it was moved. "nullptr" if
542  // this block was the only block in the region.
543  Block *insertBeforeBlock;
544 };
545 
546 /// Block type conversion. This rewrite is partially reflected in the IR.
547 class BlockTypeConversionRewrite : public BlockRewrite {
548 public:
549  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
550  Block *origBlock, Block *newBlock)
551  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
552  newBlock(newBlock) {}
553 
554  static bool classof(const IRRewrite *rewrite) {
555  return rewrite->getKind() == Kind::BlockTypeConversion;
556  }
557 
558  Block *getOrigBlock() const { return block; }
559 
560  Block *getNewBlock() const { return newBlock; }
561 
562  void commit(RewriterBase &rewriter) override;
563 
564  void rollback() override;
565 
566 private:
567  /// The new block that was created as part of this signature conversion.
568  Block *newBlock;
569 };
570 
571 /// Replacing a value. This rewrite is not immediately reflected in the
572 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
573 /// until the rewrite is committed.
574 class ReplaceValueRewrite : public ValueRewrite {
575 public:
576  ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
577  const TypeConverter *converter)
578  : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
579  converter(converter) {}
580 
581  static bool classof(const IRRewrite *rewrite) {
582  return rewrite->getKind() == Kind::ReplaceValue;
583  }
584 
585  void commit(RewriterBase &rewriter) override;
586 
587  void rollback() override;
588 
589 private:
590  /// The current type converter when the value was replaced.
591  const TypeConverter *converter;
592 };
593 
594 /// An operation rewrite.
595 class OperationRewrite : public IRRewrite {
596 public:
597  /// Return the operation that this rewrite operates on.
598  Operation *getOperation() const { return op; }
599 
600  static bool classof(const IRRewrite *rewrite) {
601  return rewrite->getKind() >= Kind::MoveOperation &&
602  rewrite->getKind() <= Kind::UnresolvedMaterialization;
603  }
604 
605 protected:
606  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
607  Operation *op)
608  : IRRewrite(kind, rewriterImpl), op(op) {}
609 
610  // The operation that this rewrite operates on.
611  Operation *op;
612 };
613 
614 /// Moving of an operation. This rewrite is immediately reflected in the IR.
615 class MoveOperationRewrite : public OperationRewrite {
616 public:
617  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
618  Operation *op, OpBuilder::InsertPoint previous)
619  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
620  block(previous.getBlock()),
621  insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
622  ? nullptr
623  : &*previous.getPoint()) {}
624 
625  static bool classof(const IRRewrite *rewrite) {
626  return rewrite->getKind() == Kind::MoveOperation;
627  }
628 
629  void commit(RewriterBase &rewriter) override {
630  // The operation was already moved. Just inform the listener.
631  if (auto *listener = rewriter.getListener()) {
632  // Note: `previousIt` cannot be passed because this is a delayed
633  // notification and iterators into past IR state cannot be represented.
634  listener->notifyOperationInserted(
635  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
636  /*insertPt=*/{}));
637  }
638  }
639 
640  void rollback() override {
641  // Move the operation back to its original position.
642  Block::iterator before =
643  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
644  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
645  }
646 
647 private:
648  // The block in which this operation was previously contained.
649  Block *block;
650 
651  // The original successor of this operation before it was moved. "nullptr"
652  // if this operation was the only operation in the region.
653  Operation *insertBeforeOp;
654 };
655 
656 /// In-place modification of an op. This rewrite is immediately reflected in
657 /// the IR. The previous state of the operation is stored in this object.
658 class ModifyOperationRewrite : public OperationRewrite {
659 public:
660  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
661  Operation *op)
662  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
663  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
664  operands(op->operand_begin(), op->operand_end()),
665  successors(op->successor_begin(), op->successor_end()) {
666  if (OpaqueProperties prop = op->getPropertiesStorage()) {
667  // Make a copy of the properties.
668  propertiesStorage = operator new(op->getPropertiesStorageSize());
669  OpaqueProperties propCopy(propertiesStorage);
670  name.initOpProperties(propCopy, /*init=*/prop);
671  }
672  }
673 
674  static bool classof(const IRRewrite *rewrite) {
675  return rewrite->getKind() == Kind::ModifyOperation;
676  }
677 
678  ~ModifyOperationRewrite() override {
679  assert(!propertiesStorage &&
680  "rewrite was neither committed nor rolled back");
681  }
682 
683  void commit(RewriterBase &rewriter) override {
684  // Notify the listener that the operation was modified in-place.
685  if (auto *listener =
686  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
687  listener->notifyOperationModified(op);
688 
689  if (propertiesStorage) {
690  OpaqueProperties propCopy(propertiesStorage);
691  // Note: The operation may have been erased in the mean time, so
692  // OperationName must be stored in this object.
693  name.destroyOpProperties(propCopy);
694  operator delete(propertiesStorage);
695  propertiesStorage = nullptr;
696  }
697  }
698 
699  void rollback() override {
700  op->setLoc(loc);
701  op->setAttrs(attrs);
702  op->setOperands(operands);
703  for (const auto &it : llvm::enumerate(successors))
704  op->setSuccessor(it.value(), it.index());
705  if (propertiesStorage) {
706  OpaqueProperties propCopy(propertiesStorage);
707  op->copyProperties(propCopy);
708  name.destroyOpProperties(propCopy);
709  operator delete(propertiesStorage);
710  propertiesStorage = nullptr;
711  }
712  }
713 
714 private:
715  OperationName name;
716  LocationAttr loc;
717  DictionaryAttr attrs;
718  SmallVector<Value, 8> operands;
719  SmallVector<Block *, 2> successors;
720  void *propertiesStorage = nullptr;
721 };
722 
723 /// Replacing an operation. Erasing an operation is treated as a special case
724 /// with "null" replacements. This rewrite is not immediately reflected in the
725 /// IR. An internal IR mapping is updated, but values are not replaced and the
726 /// original op is not erased until the rewrite is committed.
727 class ReplaceOperationRewrite : public OperationRewrite {
728 public:
729  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
730  Operation *op, const TypeConverter *converter)
731  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
732  converter(converter) {}
733 
734  static bool classof(const IRRewrite *rewrite) {
735  return rewrite->getKind() == Kind::ReplaceOperation;
736  }
737 
738  void commit(RewriterBase &rewriter) override;
739 
740  void rollback() override;
741 
742  void cleanup(RewriterBase &rewriter) override;
743 
744 private:
745  /// An optional type converter that can be used to materialize conversions
746  /// between the new and old values if necessary.
747  const TypeConverter *converter;
748 };
749 
750 class CreateOperationRewrite : public OperationRewrite {
751 public:
752  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
753  Operation *op)
754  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
755 
756  static bool classof(const IRRewrite *rewrite) {
757  return rewrite->getKind() == Kind::CreateOperation;
758  }
759 
760  void commit(RewriterBase &rewriter) override {
761  // The operation was already created and inserted. Just inform the listener.
762  if (auto *listener = rewriter.getListener())
763  listener->notifyOperationInserted(op, /*previous=*/{});
764  }
765 
766  void rollback() override;
767 };
768 
/// The direction of a materialization.
enum MaterializationKind {
  /// Materializes a conversion from an illegal type to a legal one.
  Target = 0,

  /// Materializes a conversion from a legal type back to an illegal one.
  Source = 1
};
779 
780 /// Helper class that stores metadata about an unresolved materialization.
781 class UnresolvedMaterializationInfo {
782 public:
783  UnresolvedMaterializationInfo() = default;
784  UnresolvedMaterializationInfo(const TypeConverter *converter,
785  MaterializationKind kind, Type originalType)
786  : converterAndKind(converter, kind), originalType(originalType) {}
787 
788  /// Return the type converter of this materialization (which may be null).
789  const TypeConverter *getConverter() const {
790  return converterAndKind.getPointer();
791  }
792 
793  /// Return the kind of this materialization.
794  MaterializationKind getMaterializationKind() const {
795  return converterAndKind.getInt();
796  }
797 
798  /// Return the original type of the SSA value.
799  Type getOriginalType() const { return originalType; }
800 
801 private:
802  /// The corresponding type converter to use when resolving this
803  /// materialization, and the kind of this materialization.
804  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
805  converterAndKind;
806 
807  /// The original type of the SSA value. Only used for target
808  /// materializations.
809  Type originalType;
810 };
811 
812 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
813 /// op. Unresolved materializations fold away or are replaced with
814 /// source/target materializations at the end of the dialect conversion.
815 class UnresolvedMaterializationRewrite : public OperationRewrite {
816 public:
817  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
818  UnrealizedConversionCastOp op,
819  ValueVector mappedValues)
820  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
821  mappedValues(std::move(mappedValues)) {}
822 
823  static bool classof(const IRRewrite *rewrite) {
824  return rewrite->getKind() == Kind::UnresolvedMaterialization;
825  }
826 
827  void rollback() override;
828 
829  UnrealizedConversionCastOp getOperation() const {
830  return cast<UnrealizedConversionCastOp>(op);
831  }
832 
833 private:
834  /// The values in the conversion value mapping that are being replaced by the
835  /// results of this unresolved materialization.
836  ValueVector mappedValues;
837 };
838 } // namespace
839 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// operation is `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  return llvm::any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed != nullptr && typed->getOperation() == op;
  });
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// block is `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  return llvm::any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed != nullptr && typed->getBlock() == block;
  });
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
861 
862 //===----------------------------------------------------------------------===//
863 // ConversionPatternRewriterImpl
864 //===----------------------------------------------------------------------===//
865 namespace mlir {
866 namespace detail {
869  const ConversionConfig &config)
870  : rewriter(rewriter), config(config),
871  notifyingRewriter(rewriter.getContext(), config.listener) {}
872 
873  //===--------------------------------------------------------------------===//
874  // State Management
875  //===--------------------------------------------------------------------===//
876 
877  /// Return the current state of the rewriter.
878  RewriterState getCurrentState();
879 
880  /// Apply all requested operation rewrites. This method is invoked when the
881  /// conversion process succeeds.
882  void applyRewrites();
883 
884  /// Reset the state of the rewriter to a previously saved point. Optionally,
885  /// the name of the pattern that triggered the rollback can be specified for
886  /// debugging purposes.
887  void resetState(RewriterState state, StringRef patternName = "");
888 
889  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
890  /// failure.
891  template <typename RewriteTy, typename... Args>
892  void appendRewrite(Args &&...args) {
893  assert(config.allowPatternRollback && "appending rewrites is not allowed");
894  rewrites.push_back(
895  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
896  }
897 
898  /// Undo the rewrites (motions, splits) one by one in reverse order until
899  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
900  /// that triggered the rollback can be specified for debugging purposes.
901  void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
902 
903  /// Remap the given values to those with potentially different types. Returns
904  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
905  /// is the tag used when describing a value within a diagnostic, e.g.
906  /// "operand".
907  LogicalResult remapValues(StringRef valueDiagTag,
908  std::optional<Location> inputLoc, ValueRange values,
909  SmallVector<ValueVector> &remapped);
910 
911  /// Return "true" if the given operation is ignored, and does not need to be
912  /// converted.
913  bool isOpIgnored(Operation *op) const;
914 
915  /// Return "true" if the given operation was replaced or erased.
916  bool wasOpReplaced(Operation *op) const;
917 
918  /// Lookup the most recently mapped values with the desired types in the
919  /// mapping, taking into account only replacements. Perform a best-effort
920  /// search for existing materializations with the desired types.
921  ///
922  /// If `skipPureTypeConversions` is "true", materializations that are pure
923  /// type conversions are not considered.
924  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
925  bool skipPureTypeConversions = false) const;
926 
927  /// Lookup the given value within the map, or return an empty vector if the
928  /// value is not mapped. If it is mapped, this follows the same behavior
929  /// as `lookupOrDefault`.
930  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
931 
932  //===--------------------------------------------------------------------===//
933  // IR Rewrites / Type Conversion
934  //===--------------------------------------------------------------------===//
935 
936  /// Convert the types of block arguments within the given region.
937  FailureOr<Block *>
938  convertRegionTypes(Region *region, const TypeConverter &converter,
939  TypeConverter::SignatureConversion *entryConversion);
940 
941  /// Apply the given signature conversion on the given block. The new block
942  /// containing the updated signature is returned. If no conversions were
943  /// necessary, e.g. if the block has no arguments, `block` is returned.
944  /// `converter` is used to generate any necessary cast operations that
945  /// translate between the origin argument types and those specified in the
946  /// signature conversion.
947  Block *applySignatureConversion(
948  Block *block, const TypeConverter *converter,
949  TypeConverter::SignatureConversion &signatureConversion);
950 
951  /// Replace the results of the given operation with the given values and
952  /// erase the operation.
953  ///
954  /// There can be multiple replacement values for each result (1:N
955  /// replacement). If the replacement values are empty, the respective result
956  /// is dropped and a source materialization is built if the result still has
957  /// uses.
958  void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
959 
960  /// Replace the uses of the given value with the given values. The specified
961  /// converter is used to build materializations (if necessary).
962  void replaceAllUsesWith(Value from, ValueRange to,
963  const TypeConverter *converter);
964 
965  /// Erase the given block and its contents.
966  void eraseBlock(Block *block);
967 
968  /// Inline the source block into the destination block before the given
969  /// iterator.
970  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
971 
972  //===--------------------------------------------------------------------===//
973  // Materializations
974  //===--------------------------------------------------------------------===//
975 
976  /// Build an unresolved materialization operation given a range of output
977  /// types and a list of input operands. The input types must not already
978  /// match the output types; in that case no materialization is necessary.
979  ///
980  /// If a cast op was built, it can optionally be returned with the `castOp`
981  /// output argument.
982  ///
983  /// If `valuesToMap` is set to a non-null Value, then that value is mapped to
984  /// the results of the unresolved materialization in the conversion value
985  /// mapping.
986  ///
987  /// If `isPureTypeConversion` is "true", the materialization is created only
988  /// to resolve a type mismatch. That means it is not a regular value
989  /// replacement issued by the user. (Replacement values that are created
990  /// "out of thin air" appear like unresolved materializations because they are
991  /// unrealized_conversion_cast ops. However, they must be treated like
992  /// regular value replacements.)
993  ValueRange buildUnresolvedMaterialization(
994  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
995  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
996  Type originalType, const TypeConverter *converter,
997  bool isPureTypeConversion = true);
998 
999  /// Find a replacement value for the given SSA value in the conversion value
1000  /// mapping. The replacement value must have the same type as the given SSA
1001  /// value. If there is no replacement value with the correct type, find the
1002  /// latest replacement value (regardless of the type) and build a source
1003  /// materialization.
1004  Value findOrBuildReplacementValue(Value value,
1005  const TypeConverter *converter);
1006 
1007  //===--------------------------------------------------------------------===//
1008  // Rewriter Notification Hooks
1009  //===--------------------------------------------------------------------===//
1010 
1011  /// Notifies that an op was inserted.
1012  void notifyOperationInserted(Operation *op,
1013  OpBuilder::InsertPoint previous) override;
1014 
1015  /// Notifies that a block was inserted.
1016  void notifyBlockInserted(Block *block, Region *previous,
1017  Region::iterator previousIt) override;
1018 
1019  /// Notifies that a pattern match failed for the given reason.
1020  void
1021  notifyMatchFailure(Location loc,
1022  function_ref<void(Diagnostic &)> reasonCallback) override;
1023 
1024  //===--------------------------------------------------------------------===//
1025  // IR Erasure
1026  //===--------------------------------------------------------------------===//
1027 
1028  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1029  /// operation or block is erased multiple times. This rewriter assumes that
1030  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
1032  public:
1034  MLIRContext *context,
1035  std::function<void(Operation *)> opErasedCallback = nullptr)
1036  : RewriterBase(context, /*listener=*/this),
1037  opErasedCallback(opErasedCallback) {}
1038 
1039  /// Erase the given op (unless it was already erased).
1040  void eraseOp(Operation *op) override {
1041  if (wasErased(op))
1042  return;
1043  op->dropAllUses();
1045  }
1046 
1047  /// Erase the given block (unless it was already erased).
1048  void eraseBlock(Block *block) override {
1049  if (wasErased(block))
1050  return;
1051  assert(block->empty() && "expected empty block");
1052  block->dropAllDefinedValueUses();
1053  RewriterBase::eraseBlock(block);
1054  }
1055 
1056  bool wasErased(void *ptr) const { return erased.contains(ptr); }
1057 
1058  void notifyOperationErased(Operation *op) override {
1059  erased.insert(op);
1060  if (opErasedCallback)
1061  opErasedCallback(op);
1062  }
1063 
1064  void notifyBlockErased(Block *block) override { erased.insert(block); }
1065 
1066  private:
1067  /// Pointers to all erased operations and blocks.
1068  DenseSet<void *> erased;
1069 
1070  /// A callback that is invoked when an operation is erased.
1071  std::function<void(Operation *)> opErasedCallback;
1072  };
1073 
1074  //===--------------------------------------------------------------------===//
1075  // State
1076  //===--------------------------------------------------------------------===//
1077 
1078  /// The rewriter that is used to perform the conversion.
1080 
1081  /// Mapping between replaced values that differ in type. This happens when
1082  /// replacing a value with one of a different type.
1083  ConversionValueMapping mapping;
1084 
1085  /// Ordered list of block operations (creations, splits, motions).
1086  /// This vector is maintained only if `allowPatternRollback` is set to
1087  /// "true". Otherwise, all IR rewrites are materialized immediately and no
1088  /// bookkeeping is needed.
1090 
1091  /// A set of operations that should no longer be considered for legalization.
1092  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1093  /// tracked separately.
1095 
1096  /// A set of operations that were replaced/erased. Such ops are not erased
1097  /// immediately but only when the dialect conversion succeeds. In the
1098  /// meantime, they should no longer be considered for legalization and any
1099  /// attempt to modify/access them is invalid rewriter API usage.
1101 
1102  /// A set of operations that were created by the current pattern.
1104 
1105  /// A set of operations that were modified by the current pattern.
1107 
1108  /// A set of blocks that were inserted (newly-created blocks or moved blocks)
1109  /// by the current pattern.
1111 
1112  /// A list of unresolved materializations that were created by the current
1113  /// pattern.
1115 
1116  /// A mapping for looking up metadata of unresolved materializations.
1119 
1120  /// The current type converter, or nullptr if no type converter is currently
1121  /// active.
1122  const TypeConverter *currentTypeConverter = nullptr;
1123 
1124  /// A mapping of regions to type converters that should be used when
1125  /// converting the arguments of blocks within that region.
1127 
1128  /// Dialect conversion configuration.
1130 
1131  /// A set of erased operations. This set is utilized only if
1132  /// `allowPatternRollback` is set to "false". Conceptually, this set is
1133  /// similar to `replacedOps` (which is maintained when the flag is set to
1134  /// "true"). However, erasing from a DenseSet is more efficient than erasing
1135  /// from a SetVector.
1137 
1138  /// A set of erased blocks. This set is utilized only if
1139  /// `allowPatternRollback` is set to "false".
1141 
1142  /// A rewriter that notifies the listener (if any) about all IR
1143  /// modifications. This rewriter is utilized only if `allowPatternRollback`
1144  /// is set to "false". If the flag is set to "true", the listener is notified
1145  /// with a separate mechanism (e.g., in `IRRewrite::commit`).
1147 
1148 #ifndef NDEBUG
1149  /// A set of replaced values. This set is for debugging purposes only and it
1150  /// is maintained only if `allowPatternRollback` is set to "true".
1152 
1153  /// A set of operations that have pending updates. This tracking isn't
1154  /// strictly necessary, and is thus only active during debug builds for extra
1155  /// verification.
1157 
1158  /// A raw output stream used to prefix the debug log.
1159  llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
1160  llvm::dbgs()};
1161 
1162  /// A logger used to emit diagnostics during the conversion process.
1163  llvm::ScopedPrinter logger{os};
1164  std::string logPrefix;
1165 #endif
1166 };
1167 } // namespace detail
1168 } // namespace mlir
1169 
1170 const ConversionConfig &IRRewrite::getConfig() const {
1171  return rewriterImpl.config;
1172 }
1173 
1174 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1175  // Inform the listener about all IR modifications that have already taken
1176  // place: References to the original block have been replaced with the new
1177  // block.
1178  if (auto *listener =
1179  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1180  for (Operation *op : getNewBlock()->getUsers())
1181  listener->notifyOperationModified(op);
1182 }
1183 
1184 void BlockTypeConversionRewrite::rollback() {
1185  getNewBlock()->replaceAllUsesWith(getOrigBlock());
1186 }
1187 
1188 /// Replace all uses of `from` with `repl`.
1189 static void performReplaceValue(RewriterBase &rewriter, Value from,
1190  Value repl) {
1191  if (isa<BlockArgument>(repl)) {
1192  // `repl` is a block argument. Directly replace all uses.
1193  rewriter.replaceAllUsesWith(from, repl);
1194  return;
1195  }
1196 
1197  // If the replacement value is an operation, only replace those uses that:
1198  // - are in a different block than the replacement operation, or
1199  // - are in the same block but after the replacement operation.
1200  //
1201  // Example:
1202  // ^bb0(%arg0: i32):
1203  // %0 = "consumer"(%arg0) : (i32) -> (i32)
1204  // "another_consumer"(%arg0) : (i32) -> ()
1205  //
1206  // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1207  // use in "another_consumer" but not the use in "consumer". When using the
1208  // normal RewriterBase API, this would typically be done with
1209  // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1210  // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1211  // it cannot be supported efficiently with `allowPatternRollback` set to
1212  // "true". Therefore, the conversion driver is trying to be smart and replaces
1213  // only those uses that do not lead to a dominance violation. E.g., the
1214  // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1215  // behavior.
1216  //
1217  // TODO: As we move more and more towards `allowPatternRollback` set to
1218  // "false", we should remove this special handling, in order to align the
1219  // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1220  Operation *replOp = repl.getDefiningOp();
1221  Block *replBlock = replOp->getBlock();
1222  rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
1223  Operation *user = operand.getOwner();
1224  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1225  });
1226 }
1227 
1228 void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1229  Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
1230  if (!repl)
1231  return;
1232  performReplaceValue(rewriter, value, repl);
1233 }
1234 
1235 void ReplaceValueRewrite::rollback() {
1236  rewriterImpl.mapping.erase({value});
1237 #ifndef NDEBUG
1238  rewriterImpl.replacedValues.erase(value);
1239 #endif // NDEBUG
1240 }
1241 
1242 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1243  auto *listener =
1244  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1245 
1246  // Compute replacement values.
1247  SmallVector<Value> replacements =
1248  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1249  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1250  });
1251 
1252  // Notify the listener that the operation is about to be replaced.
1253  if (listener)
1254  listener->notifyOperationReplaced(op, replacements);
1255 
1256  // Replace all uses with the new values.
1257  for (auto [result, newValue] :
1258  llvm::zip_equal(op->getResults(), replacements))
1259  if (newValue)
1260  rewriter.replaceAllUsesWith(result, newValue);
1261 
1262  // The original op will be erased, so remove it from the set of unlegalized
1263  // ops.
1264  if (getConfig().unlegalizedOps)
1265  getConfig().unlegalizedOps->erase(op);
1266 
1267  // Notify the listener that the operation and its contents are being erased.
1268  if (listener)
1269  notifyIRErased(listener, *op);
1270 
1271  // Do not erase the operation yet. It may still be referenced in `mapping`.
1272  // Just unlink it for now and erase it during cleanup.
1273  op->getBlock()->getOperations().remove(op);
1274 }
1275 
1276 void ReplaceOperationRewrite::rollback() {
1277  for (auto result : op->getResults())
1278  rewriterImpl.mapping.erase({result});
1279 }
1280 
1281 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1282  rewriter.eraseOp(op);
1283 }
1284 
1285 void CreateOperationRewrite::rollback() {
1286  for (Region &region : op->getRegions()) {
1287  while (!region.getBlocks().empty())
1288  region.getBlocks().remove(region.getBlocks().begin());
1289  }
1290  op->dropAllUses();
1291  op->erase();
1292 }
1293 
1294 void UnresolvedMaterializationRewrite::rollback() {
1295  if (!mappedValues.empty())
1296  rewriterImpl.mapping.erase(mappedValues);
1297  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1298  op->erase();
1299 }
1300 
1302  // Commit all rewrites. Use a new rewriter, so the modifications are not
1303  // tracked for rollback purposes etc.
1304  IRRewriter irRewriter(rewriter.getContext(), config.listener);
1305  // Note: New rewrites may be added during the "commit" phase and the
1306  // `rewrites` vector may reallocate.
1307  for (size_t i = 0; i < rewrites.size(); ++i)
1308  rewrites[i]->commit(irRewriter);
1309 
1310  // Clean up all rewrites.
1311  SingleEraseRewriter eraseRewriter(
1312  rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
1313  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1314  unresolvedMaterializations.erase(castOp);
1315  });
1316  for (auto &rewrite : rewrites)
1317  rewrite->cleanup(eraseRewriter);
1318 }
1319 
1320 //===----------------------------------------------------------------------===//
1321 // State Management
1322 //===----------------------------------------------------------------------===//
1323 
1325  Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1326  // Helper function that looks up a single value.
1327  auto lookup = [&](const ValueVector &values) -> ValueVector {
1328  assert(!values.empty() && "expected non-empty value vector");
1329 
1330  // If the pattern rollback is enabled, use the mapping to look up the
1331  // values.
1333  return mapping.lookup(values);
1334 
1335  // Otherwise, look up values by examining the IR. All replacements have
1336  // already been materialized in IR.
1337  Operation *op = getCommonDefiningOp(values);
1338  if (!op)
1339  return {};
1340  auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1341  if (!castOp)
1342  return {};
1343  if (!this->unresolvedMaterializations.contains(castOp))
1344  return {};
1345  if (castOp.getOutputs() != values)
1346  return {};
1347  return castOp.getInputs();
1348  };
1349 
1350  // Helper function that looks up each value in `values` individually and then
1351  // composes the results. If that fails, it tries to look up the entire vector
1352  // at once.
1353  auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1354  // If possible, replace each value with (one or multiple) mapped values.
1355  ValueVector next;
1356  for (Value v : values) {
1357  ValueVector r = lookup({v});
1358  if (!r.empty()) {
1359  llvm::append_range(next, r);
1360  } else {
1361  next.push_back(v);
1362  }
1363  }
1364  if (next != values) {
1365  // At least one value was replaced.
1366  return next;
1367  }
1368 
1369  // Otherwise: Check if there is a mapping for the entire vector. Such
1370  // mappings are materializations. (N:M mapping are not supported for value
1371  // replacements.)
1372  //
1373  // Note: From a correctness point of view, materializations do not have to
1374  // be stored (and looked up) in the mapping. But for performance reasons,
1375  // we choose to reuse existing IR (when possible) instead of creating it
1376  // multiple times.
1377  ValueVector r = lookup(values);
1378  if (r.empty()) {
1379  // No mapping found: The lookup stops here.
1380  return {};
1381  }
1382  return r;
1383  };
1384 
1385  // Try to find the deepest values that have the desired types. If there is no
1386  // such mapping, simply return the deepest values.
1387  ValueVector desiredValue;
1388  ValueVector current{from};
1389  ValueVector lastNonMaterialization{from};
1390  do {
1391  // Store the current value if the types match.
1392  bool match = TypeRange(ValueRange(current)) == desiredTypes;
1393  if (skipPureTypeConversions) {
1394  // Skip pure type conversions, if requested.
1395  bool pureConversion = isPureTypeConversion(current);
1396  match &= !pureConversion;
1397  // Keep track of the last mapped value that was not a pure type
1398  // conversion.
1399  if (!pureConversion)
1400  lastNonMaterialization = current;
1401  }
1402  if (match)
1403  desiredValue = current;
1404 
1405  // Lookup next value in the mapping.
1406  ValueVector next = composedLookup(current);
1407  if (next.empty())
1408  break;
1409  current = std::move(next);
1410  } while (true);
1411 
1412  // If the desired values were found use them, otherwise default to the leaf
1413  // values. (Skip pure type conversions, if requested.)
1414  if (!desiredTypes.empty())
1415  return desiredValue;
1416  if (skipPureTypeConversions)
1417  return lastNonMaterialization;
1418  return current;
1419 }
1420 
1423  TypeRange desiredTypes) const {
1424  ValueVector result = lookupOrDefault(from, desiredTypes);
1425  if (result == ValueVector{from} ||
1426  (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
1427  return {};
1428  return result;
1429 }
1430 
1432  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1433 }
1434 
1436  StringRef patternName) {
1437  // Undo any rewrites.
1438  undoRewrites(state.numRewrites, patternName);
1439 
1440  // Pop all of the recorded ignored operations that are no longer valid.
1441  while (ignoredOps.size() != state.numIgnoredOperations)
1442  ignoredOps.pop_back();
1443 
1444  while (replacedOps.size() != state.numReplacedOps)
1445  replacedOps.pop_back();
1446 }
1447 
1448 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1449  StringRef patternName) {
1450  for (auto &rewrite :
1451  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1452  rewrite->rollback();
1453  rewrites.resize(numRewritesToKeep);
1454 }
1455 
1457  StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1458  SmallVector<ValueVector> &remapped) {
1459  remapped.reserve(llvm::size(values));
1460 
1461  for (const auto &it : llvm::enumerate(values)) {
1462  Value operand = it.value();
1463  Type origType = operand.getType();
1464  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1465 
1466  if (!currentTypeConverter) {
1467  // The current pattern does not have a type converter. Pass the most
1468  // recently mapped values, excluding materializations. Materializations
1469  // are intentionally excluded because their presence may depend on other
1470  // patterns. Including materializations would make the lookup fragile
1471  // and unpredictable.
1472  remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1473  /*skipPureTypeConversions=*/true));
1474  continue;
1475  }
1476 
1477  // If there is no legal conversion, fail to match this pattern.
1478  SmallVector<Type, 1> legalTypes;
1479  if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
1480  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1481  diag << "unable to convert type for " << valueDiagTag << " #"
1482  << it.index() << ", type was " << origType;
1483  });
1484  return failure();
1485  }
1486  // If a type is converted to 0 types, there is nothing to do.
1487  if (legalTypes.empty()) {
1488  remapped.push_back({});
1489  continue;
1490  }
1491 
1492  ValueVector repl = lookupOrDefault(operand, legalTypes);
1493  if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1494  // Mapped values have the correct type or there is an existing
1495  // materialization. Or the operand is not mapped at all and has the
1496  // correct type.
1497  remapped.push_back(std::move(repl));
1498  continue;
1499  }
1500 
1501  // Create a materialization for the most recently mapped values.
1502  repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1503  /*skipPureTypeConversions=*/true);
1505  MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1506  /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1507  /*originalType=*/origType, currentTypeConverter);
1508  remapped.push_back(castValues);
1509  }
1510  return success();
1511 }
1512 
1514  // Check to see if this operation is ignored or was replaced.
1515  return wasOpReplaced(op) || ignoredOps.count(op);
1516 }
1517 
1519  // Check to see if this operation was replaced.
1520  return replacedOps.count(op) || erasedOps.count(op);
1521 }
1522 
1523 //===----------------------------------------------------------------------===//
1524 // Type Conversion
1525 //===----------------------------------------------------------------------===//
1526 
1528  Region *region, const TypeConverter &converter,
1529  TypeConverter::SignatureConversion *entryConversion) {
1530  regionToConverter[region] = &converter;
1531  if (region->empty())
1532  return nullptr;
1533 
1534  // Convert the arguments of each non-entry block within the region.
1535  for (Block &block :
1536  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1537  // Compute the signature for the block with the provided converter.
1538  std::optional<TypeConverter::SignatureConversion> conversion =
1539  converter.convertBlockSignature(&block);
1540  if (!conversion)
1541  return failure();
1542  // Convert the block with the computed signature.
1543  applySignatureConversion(&block, &converter, *conversion);
1544  }
1545 
1546  // Convert the entry block. If an entry signature conversion was provided,
1547  // use that one. Otherwise, compute the signature with the type converter.
1548  if (entryConversion)
1549  return applySignatureConversion(&region->front(), &converter,
1550  *entryConversion);
1551  std::optional<TypeConverter::SignatureConversion> conversion =
1552  converter.convertBlockSignature(&region->front());
1553  if (!conversion)
1554  return failure();
1555  return applySignatureConversion(&region->front(), &converter, *conversion);
1556 }
1557 
1559  Block *block, const TypeConverter *converter,
1560  TypeConverter::SignatureConversion &signatureConversion) {
1561 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1562  // A block cannot be converted multiple times.
1563  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1564  llvm::report_fatal_error("block was already converted");
1565 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1566 
1568 
1569  // If no arguments are being changed or added, there is nothing to do.
1570  unsigned origArgCount = block->getNumArguments();
1571  auto convertedTypes = signatureConversion.getConvertedTypes();
1572  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1573  return block;
1574 
1575  // Compute the locations of all block arguments in the new block.
1576  SmallVector<Location> newLocs(convertedTypes.size(),
1578  for (unsigned i = 0; i < origArgCount; ++i) {
1579  auto inputMap = signatureConversion.getInputMapping(i);
1580  if (!inputMap || inputMap->replacedWithValues())
1581  continue;
1582  Location origLoc = block->getArgument(i).getLoc();
1583  for (unsigned j = 0; j < inputMap->size; ++j)
1584  newLocs[inputMap->inputNo + j] = origLoc;
1585  }
1586 
1587  // Insert a new block with the converted block argument types and move all ops
1588  // from the old block to the new block.
1589  Block *newBlock =
1590  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1591  convertedTypes, newLocs);
1592 
1593  // If a listener is attached to the dialect conversion, ops cannot be moved
1594  // to the destination block in bulk ("fast path"). This is because at the time
1595  // the notifications are sent, it is unknown which ops were moved. Instead,
1596  // ops should be moved one-by-one ("slow path"), so that a separate
1597  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1598  // a bit more efficient, so we try to do that when possible.
1599  bool fastPath = !config.listener;
1600  if (fastPath) {
1602  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1603  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1604  } else {
1605  while (!block->empty())
1606  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1607  }
1608 
1609  // Replace all uses of the old block with the new block.
1610  block->replaceAllUsesWith(newBlock);
1611 
1612  for (unsigned i = 0; i != origArgCount; ++i) {
1613  BlockArgument origArg = block->getArgument(i);
1614  Type origArgType = origArg.getType();
1615 
1616  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1617  signatureConversion.getInputMapping(i);
1618  if (!inputMap) {
1619  // This block argument was dropped and no replacement value was provided.
1620  // Materialize a replacement value "out of thin air".
1621  // Note: Materialization must be built here because we cannot find a
1622  // valid insertion point in the new block. (Will point to the old block.)
1623  Value mat =
1625  MaterializationKind::Source,
1626  OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1627  origArg.getLoc(),
1628  /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1629  /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1630  /*isPureTypeConversion=*/false)
1631  .front();
1632  replaceAllUsesWith(origArg, mat, converter);
1633  continue;
1634  }
1635 
1636  if (inputMap->replacedWithValues()) {
1637  // This block argument was dropped and replacement values were provided.
1638  assert(inputMap->size == 0 &&
1639  "invalid to provide a replacement value when the argument isn't "
1640  "dropped");
1641  replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
1642  continue;
1643  }
1644 
1645  // This is a 1->1+ mapping.
1646  auto replArgs =
1647  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1648  replaceAllUsesWith(origArg, replArgs, converter);
1649  }
1650 
1652  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1653 
1654  // Erase the old block. (It is just unlinked for now and will be erased during
1655  // cleanup.)
1656  rewriter.eraseBlock(block);
1657 
1658  return newBlock;
1659 }
1660 
1661 //===----------------------------------------------------------------------===//
1662 // Materializations
1663 //===----------------------------------------------------------------------===//
1664 
1665  /// Build an unresolved materialization operation given an output type and set
1666 /// of input operands.
///
/// Creates an `unrealized_conversion_cast` with result types `outputTypes` and
/// operands `inputs` at insertion point `ip`, registers it in
/// `unresolvedMaterializations`, and returns the cast's results.
/// Preconditions (asserted): `originalType` may only be set for target
/// materializations, and the types of `inputs` must differ from `outputTypes`.
/// Depending on a branch whose condition is not visible in this listing
/// (NOTE(review): presumably `config.allowPatternRollback` — confirm), either
/// `valuesToMap` is mapped to the cast results and an
/// `UnresolvedMaterializationRewrite` is recorded, or the cast is added to
/// `patternMaterializations`.
1668  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1669  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1670  Type originalType, const TypeConverter *converter,
1671  bool isPureTypeConversion) {
1672  assert((!originalType || kind == MaterializationKind::Target) &&
1673  "original type is valid only for target materializations");
1674  assert(TypeRange(inputs) != outputTypes &&
1675  "materialization is not necessary");
1676 
1677  // Create an unresolved materialization. We use a new OpBuilder to avoid
1678  // tracking the materialization like we do for other operations.
1679  OpBuilder builder(outputTypes.front().getContext());
1680  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1681  UnrealizedConversionCastOp convertOp =
1682  UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
  // Tag the cast with its materialization kind ("source" or "target"); the
  // condition guarding this tagging is not visible in this listing.
1684  StringRef kindStr =
1685  kind == MaterializationKind::Source ? "source" : "target";
1686  convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1687  }
  // Mark the cast with `kPureTypeConversionMarker` (guard not visible here;
  // presumably `isPureTypeConversion` — confirm).
1689  convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
1690 
1691  // Register the materialization.
1692  unresolvedMaterializations[convertOp] =
1693  UnresolvedMaterializationInfo(converter, kind, originalType);
1695  if (!valuesToMap.empty())
1696  mapping.map(valuesToMap, convertOp.getResults());
1697  appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1698  std::move(valuesToMap));
1699  } else {
1700  patternMaterializations.insert(convertOp);
1701  }
1702  return convertOp.getResults();
1703 }
1704 
1706  Value value, const TypeConverter *converter) {
  // Returns a value of `value`'s own type that can stand in for `value`.
  // Tries, in order: an already-mapped value of that type (including cached
  // materializations); a null Value when `value` appears dead; otherwise a
  // new source materialization built from the latest replacement values.
  // Only valid in rollback mode (asserted).
1707  assert(config.allowPatternRollback &&
1708  "this code path is valid only in rollback mode");
1709 
1710  // Try to find a replacement value with the same type in the conversion value
1711  // mapping. This includes cached materializations. We try to reuse those
1712  // instead of generating duplicate IR.
1713  ValueVector repl = lookupOrNull(value, value.getType());
1714  if (!repl.empty())
1715  return repl.front();
1716 
1717  // Check if the value is dead. No replacement value is needed in that case.
1718  // This is an approximate check that may have false negatives but does not
1719  // require computing and traversing an inverse mapping. (We may end up
1720  // building source materializations that are never used and that fold away.)
1721  if (llvm::all_of(value.getUsers(),
1722  [&](Operation *op) { return replacedOps.contains(op); }) &&
1723  !mapping.isMappedTo(value))
1724  return Value();
1725 
1726  // No replacement value was found. Get the latest replacement value
1727  // (regardless of the type) and build a source materialization to the
1728  // original type.
1729  repl = lookupOrNull(value);
1730 
1731  // Compute the insertion point of the materialization.
1733  if (repl.empty()) {
1734  // The source materialization has no inputs. Insert it right before the
1735  // value that it is replacing.
1736  ip = computeInsertPoint(value);
1737  } else {
1738  // Compute the "earliest" insertion point at which all values in `repl` are
1739  // defined. It is important to emit the materialization at that location
1740  // because the same materialization may be reused in a different context.
1741  // (That's because materializations are cached in the conversion value
1742  // mapping.) The insertion point of the materialization must be valid for
1743  // all future users that may be created later in the conversion process.
1744  ip = computeInsertPoint(repl);
1745  }
  // Build a source materialization from `repl` back to `value`'s type. `repl`
  // is also passed as `valuesToMap`. The cast is flagged as a pure type
  // conversion only when it has inputs.
1747  MaterializationKind::Source, ip, value.getLoc(),
1748  /*valuesToMap=*/repl, /*inputs=*/repl,
1749  /*outputTypes=*/value.getType(),
1750  /*originalType=*/Type(), converter,
1751  /*isPureTypeConversion=*/!repl.empty())
1752  .front();
1753  return castValue;
1754 }
1755 
1756 //===----------------------------------------------------------------------===//
1757 // Rewriter Notification Hooks
1758 //===----------------------------------------------------------------------===//
1759 
1761  Operation *op, OpBuilder::InsertPoint previous) {
  // Listener hook called when `op` is inserted. An unset `previous` means the
  // op was detached beforehand (most likely newly created). Otherwise the op
  // was moved away from `previous`. Detached ops are added to `patternNewOps`
  // so they get legalized. The change is also recorded as a
  // CreateOperationRewrite / MoveOperationRewrite. Some guard lines of this
  // function are not visible in this listing.
1762  // If no previous insertion point is provided, the op used to be detached.
1763  bool wasDetached = !previous.isSet();
1764  LLVM_DEBUG({
1765  logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1766  << ")";
1767  if (wasDetached)
1768  logger.getOStream() << " (was detached)";
1769  logger.getOStream() << "\n";
1770  });
1771 
1772  // In rollback mode, it is easier to misuse the API, so perform extra error
1773  // checking.
1774  assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1775  "attempting to insert into a block within a replaced/erased op");
1776 
1777  // In "no rollback" mode, the listener is always notified immediately.
1779  config.listener->notifyOperationInserted(op, previous);
1780 
1781  if (wasDetached) {
1782  // If the op was detached, it is most likely a newly created op. Add it to the
1783  // set of newly created ops, so that it will be legalized. If this op is
1784  // not a newly created op, it will be legalized a second time, which is
1785  // inefficient but harmless.
1786  patternNewOps.insert(op);
1787 
1789  // TODO: If the same op is inserted multiple times from a detached
1790  // state, the rollback mechanism may erase the same op multiple times.
1791  // This is a bug in the rollback-based dialect conversion driver.
1792  appendRewrite<CreateOperationRewrite>(op);
1793  } else {
1794  // In "no rollback" mode, there is an extra data structure for tracking
1795  // erased operations that must be kept up to date.
1796  erasedOps.erase(op);
1797  }
1798  return;
1799  }
1800 
1801  // The op was moved from one place to another.
1803  appendRewrite<MoveOperationRewrite>(op, previous);
1804 }
1805 
1806 /// Given that `fromRange` is about to be replaced with `toRange`, compute
1807 /// replacement values with the types of `fromRange`.
///
/// Only valid in "no rollback" mode (asserted). For each (from, to) pair the
/// result is:
///  - a null Value if `from` has no uses;
///  - a source materialization with no inputs if `to` is empty (dropped value);
///  - `to[0]` if `to` is exactly one value of `from`'s type;
///  - otherwise a source materialization of `to` back to `from`'s type.
/// `fromRange` and `toRange` must have equal length (`zip_equal`).
1808 static SmallVector<Value>
1810  const SmallVector<SmallVector<Value>> &toRange,
1811  const TypeConverter *converter) {
1812  assert(!impl.config.allowPatternRollback &&
1813  "this code path is valid only in 'no rollback' mode");
1814  SmallVector<Value> repls;
1815  for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1816  if (from.use_empty()) {
1817  // The replaced value is dead. No replacement value is needed.
1818  repls.push_back(Value());
1819  continue;
1820  }
1821 
1822  if (to.empty()) {
1823  // The replaced value is dropped. Materialize a replacement value "out of
1824  // thin air".
1825  Value srcMat = impl.buildUnresolvedMaterialization(
1826  MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1827  /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1828  /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1829  converter)[0];
1830  repls.push_back(srcMat);
1831  continue;
1832  }
1833 
    // This comparison holds only when `to` holds exactly one value whose type
    // equals `from`'s type.
1834  if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1835  // The replacement value already has the correct type. Use it directly.
1836  repls.push_back(to[0]);
1837  continue;
1838  }
1839 
1840  // The replacement value has the wrong type. Build a source materialization
1841  // to the original type.
1842  // TODO: This is a bit inefficient. We should try to reuse existing
1843  // materializations if possible. This would require an extension of the
1844  // `lookupOrDefault` API.
1845  Value srcMat = impl.buildUnresolvedMaterialization(
1846  MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1847  /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1848  /*originalType=*/Type(), converter)[0];
1849  repls.push_back(srcMat);
1850  }
1851 
1852  return repls;
1853 }
1854 
1856  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
  // Replaces each result of `op` with the corresponding entry of `newValues`.
  // An empty entry means the result is dropped.
  // - Without pattern rollback: compute type-correct replacement values, purge
  //   `op` and all nested ops/blocks from the internal tracking sets, then
  //   replace and erase `op` through `notifyingRewriter`.
  // - With rollback: only record value mappings and a ReplaceOperationRewrite,
  //   and mark `op` and all nested ops as replaced.
1857  assert(newValues.size() == op->getNumResults() &&
1858  "incorrect number of replacement values");
1859 
1861  // Pattern rollback is not allowed: materialize all IR changes immediately.
1863  *this, op->getResults(), newValues, currentTypeConverter);
1864  // Update internal data structures, so that there are no dangling pointers
1865  // to erased IR.
1866  op->walk([&](Operation *op) {
1867  erasedOps.insert(op);
1868  ignoredOps.remove(op);
1869  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1870  unresolvedMaterializations.erase(castOp);
1871  patternMaterializations.erase(castOp);
1872  }
1873  // The original op will be erased, so remove it from the set of
1874  // unlegalized ops.
1875  if (config.unlegalizedOps)
1876  config.unlegalizedOps->erase(op);
1877  });
1878  op->walk([&](Block *block) { erasedBlocks.insert(block); });
1879  // Replace the op with the replacement values and notify the listener.
1880  notifyingRewriter.replaceOp(op, repls);
1881  return;
1882  }
1883 
1884  assert(!ignoredOps.contains(op) && "operation was already replaced");
1885 #ifndef NDEBUG
1886  for (Value v : op->getResults())
1887  assert(!replacedValues.contains(v) &&
1888  "attempting to replace a value that was already replaced");
1889 #endif // NDEBUG
1890 
1891  // Check if replaced op is an unresolved materialization, i.e., an
1892  // unrealized_conversion_cast op that was created by the conversion driver.
1893  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1894  // Make sure that the user does not mess with unresolved materializations
1895  // that were inserted by the conversion driver. We keep track of these
1896  // ops in internal data structures.
1897  assert(!unresolvedMaterializations.contains(castOp) &&
1898  "attempting to replace/erase an unresolved materialization");
1899  }
1900 
1901  // Create mappings for each of the new result values.
1902  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
1903  mapping.map(static_cast<Value>(result), std::move(repl));
1904 
1905  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1906  // Mark this operation and all nested ops as replaced.
1907  op->walk([&](Operation *op) { replacedOps.insert(op); });
1908 }
1909 
1911  Value from, ValueRange to, const TypeConverter *converter) {
  // Replaces all uses of `from` with `to`.
  // - Without pattern rollback: compute one replacement value of `from`'s type
  //   (materializing it if needed) and substitute it immediately. If `from`
  //   has no uses, the replacement is null and nothing is done.
  // - With rollback: record a mapping and a ReplaceValueRewrite. Per the
  //   comment below, this also covers uses created by later patterns.
1913  SmallVector<Value> toConv = llvm::to_vector(to);
1914  SmallVector<Value> repls =
1915  getReplacementValues(*this, from, {toConv}, converter);
1916  IRRewriter r(from.getContext());
1917  Value repl = repls.front();
1918  if (!repl)
1919  return;
1920 
1921  performReplaceValue(r, from, repl);
1922  return;
1923  }
1924 
1925 #ifndef NDEBUG
1926  // Make sure that a value is not replaced multiple times. In rollback mode,
1927  // `replaceAllUsesWith` replaces not only all current uses of the given value,
1928  // but also all future uses that may be introduced by future pattern
1929  // applications. Therefore, it does not make sense to call
1930  // `replaceAllUsesWith` multiple times with the same value. Doing so would
1931  // overwrite the mapping and mess with the internal state of the dialect
1932  // conversion driver.
1933  assert(!replacedValues.contains(from) &&
1934  "attempting to replace a value that was already replaced");
1935  assert(!wasOpReplaced(from.getDefiningOp()) &&
1936  "attempting to replace a op result that was already replaced");
1937  replacedValues.insert(from);
1938 #endif // NDEBUG
1939 
1940  mapping.map(from, to);
1941  appendRewrite<ReplaceValueRewrite>(from, converter);
1942 }
1943 
  // Erases `block`.
  // - Without pattern rollback: purge all nested ops/blocks from the internal
  //   tracking sets, then erase the block immediately and notify the listener.
  // - With rollback: record an EraseBlockRewrite and only unlink the block, so
  //   it can be re-inserted on rollback. Nested ops are marked as replaced.
1946  // Pattern rollback is not allowed: materialize all IR changes immediately.
1947  // Update internal data structures, so that there are no dangling pointers
1948  // to erased IR.
1949  block->walk([&](Operation *op) {
1950  erasedOps.insert(op);
1951  ignoredOps.remove(op);
1952  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1953  unresolvedMaterializations.erase(castOp);
1954  patternMaterializations.erase(castOp);
1955  }
1956  // The original op will be erased, so remove it from the set of
1957  // unlegalized ops.
1958  if (config.unlegalizedOps)
1959  config.unlegalizedOps->erase(op);
1960  });
1961  block->walk([&](Block *block) { erasedBlocks.insert(block); });
1962  // Erase the block and notify the listener.
1964  return;
1965  }
1966 
1967  assert(!wasOpReplaced(block->getParentOp()) &&
1968  "attempting to erase a block within a replaced/erased op");
1969  appendRewrite<EraseBlockRewrite>(block);
1970 
1971  // Unlink the block from its parent region. The block is kept in the rewrite
1972  // object and will be actually destroyed when rewrites are applied. This
1973  // allows us to keep the operations in the block live and undo the removal by
1974  // re-inserting the block.
1975  block->getParent()->getBlocks().remove(block);
1976 
1977  // Mark all nested ops as erased.
1978  block->walk([&](Operation *op) { replacedOps.insert(op); });
1979 }
1980 
1982  Block *block, Region *previous, Region::iterator previousIt) {
  // Listener hook called when `block` is inserted into a region. A null
  // `previous` means the block was detached beforehand (most likely newly
  // created). Otherwise it was moved away from `previous`/`previousIt`. The
  // block is recorded in `patternInsertedBlocks`. The change is also recorded
  // as a CreateBlockRewrite / MoveBlockRewrite. Some guard lines of this
  // function are not visible in this listing.
1983  // If no previous insertion point is provided, the block used to be detached.
1984  bool wasDetached = !previous;
1985  Operation *newParentOp = block->getParentOp();
1986  LLVM_DEBUG(
1987  {
1988  Operation *parent = newParentOp;
1989  if (parent) {
1990  logger.startLine() << "** Insert Block into : '" << parent->getName()
1991  << "' (" << parent << ")";
1992  } else {
1993  logger.startLine()
1994  << "** Insert Block into detached Region (nullptr parent op)";
1995  }
1996  if (wasDetached)
1997  logger.getOStream() << " (was detached)";
1998  logger.getOStream() << "\n";
1999  });
2000 
2001  // In rollback mode, it is easier to misuse the API, so perform extra error
2002  // checking.
2003  assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2004  "attempting to insert into a region within a replaced/erased op");
  // `newParentOp` is otherwise only used in LLVM_DEBUG/assert; silence
  // unused-variable warnings in release builds.
2005  (void)newParentOp;
2006 
2007  // In "no rollback" mode, the listener is always notified immediately.
2009  config.listener->notifyBlockInserted(block, previous, previousIt);
2010 
2011  patternInsertedBlocks.insert(block);
2012 
2013  if (wasDetached) {
2014  // If the block was detached, it is most likely a newly created block.
2016  // TODO: If the same block is inserted multiple times from a detached
2017  // state, the rollback mechanism may erase the same block multiple times.
2018  // This is a bug in the rollback-based dialect conversion driver.
2019  appendRewrite<CreateBlockRewrite>(block);
2020  } else {
2021  // In "no rollback" mode, there is an extra data structure for tracking
2022  // erased blocks that must be kept up to date.
2023  erasedBlocks.erase(block);
2024  }
2025  return;
2026  }
2027 
2028  // The block was moved from one place to another.
2030  appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2031 }
2032 
2034  Block *dest,
2035  Block::iterator before) {
  // Records that the contents of `source` are being inlined into `dest` before
  // `before`. Recording it as an InlineBlockRewrite presumably lets the
  // inlining be undone on rollback — confirm in InlineBlockRewrite.
2036  appendRewrite<InlineBlockRewrite>(dest, source, before);
2037 }
2038 
2040  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
  // Renders the match-failure reason produced by `reasonCallback` into a
  // diagnostic and logs it. The diagnostic is also handed to
  // `config.notifyCallback` when that is set.
  // NOTE: the whole body sits inside LLVM_DEBUG, so nothing happens (not even
  // `config.notifyCallback`) unless debug output is compiled in and enabled.
2041  LLVM_DEBUG({
2043  reasonCallback(diag);
2044  logger.startLine() << "** Failure : " << diag.str() << "\n";
2045  if (config.notifyCallback)
2047  });
2048 }
2049 
2050 //===----------------------------------------------------------------------===//
2051 // ConversionPatternRewriter
2052 //===----------------------------------------------------------------------===//
2053 
2054 ConversionPatternRewriter::ConversionPatternRewriter(
2055  MLIRContext *ctx, const ConversionConfig &config)
2056  : PatternRewriter(ctx),
2057  impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
2058  setListener(impl.get());
2059 }
2060 
2062 
  // The configuration is owned by the implementation object.
2064  return impl->config;
2065 }
2066 
  // Replaces `op` with the results of `newOp` by forwarding `newOp`'s results
  // to `replaceOp`. Both ops must be non-null.
2068  assert(op && newOp && "expected non-null op");
2069  replaceOp(op, newOp->getResults());
2070 }
2071 
  // 1:1 replacement: result i of `op` is replaced by `newValues[i]`. Each
  // value is wrapped into a single-element vector. A null value becomes an
  // empty vector, which the implementation treats as a dropped result.
2073  assert(op->getNumResults() == newValues.size() &&
2074  "incorrect # of replacement values");
2075  LLVM_DEBUG({
2076  impl->logger.startLine()
2077  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
2078  });
2079 
2080  // If the current insertion point is before the erased operation, we adjust
2081  // the insertion point to be after the operation.
2082  if (getInsertionPoint() == op->getIterator())
2084 
2086  llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2087  return v ? SmallVector<Value>{v} : SmallVector<Value>();
2088  });
2089  impl->replaceOp(op, std::move(newVals));
2090 }
2091 
2093  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
  // 1:N replacement: result i of `op` is replaced by the (possibly empty)
  // vector `newValues[i]`. The vectors are moved into the implementation.
2094  assert(op->getNumResults() == newValues.size() &&
2095  "incorrect # of replacement values");
2096  LLVM_DEBUG({
2097  impl->logger.startLine()
2098  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
2099  });
2100 
2101  // If the current insertion point is before the erased operation, we adjust
2102  // the insertion point to be after the operation.
2103  if (getInsertionPoint() == op->getIterator())
2105 
2106  impl->replaceOp(op, std::move(newValues));
2107 }
2108 
  // Erasing `op` is modeled as a replacement in which every result gets an
  // empty replacement list (i.e., every result is dropped).
2110  LLVM_DEBUG({
2111  impl->logger.startLine()
2112  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2113  });
2114 
2115  // If the current insertion point is before the erased operation, we adjust
2116  // the insertion point to be after the operation.
2117  if (getInsertionPoint() == op->getIterator())
2119 
2120  SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2121  impl->replaceOp(op, std::move(nullRepls));
2122 }
2123 
  // All bookkeeping (rollback or immediate erasure) is done by the
  // implementation.
2125  impl->eraseBlock(block);
2126 }
2127 
2129  Block *block, TypeConverter::SignatureConversion &conversion,
2130  const TypeConverter *converter) {
  // Applies `conversion` to the arguments of `block`. `converter` is passed
  // through to the implementation. Note that the implementation takes
  // `converter` before `conversion`. The block must not be nested in a
  // replaced/erased op (asserted).
2131  assert(!impl->wasOpReplaced(block->getParentOp()) &&
2132  "attempting to apply a signature conversion to a block within a "
2133  "replaced/erased op");
2134  return impl->applySignatureConversion(block, converter, conversion);
2135 }
2136 
2138  Region *region, const TypeConverter &converter,
2139  TypeConverter::SignatureConversion *entryConversion) {
  // Converts the block signatures of `region` using `converter`.
  // `entryConversion`, when non-null, is presumably used for the entry block
  // instead of a converter-derived conversion — confirm in the implementation.
  // The region must not be nested in a replaced/erased op (asserted).
2140  assert(!impl->wasOpReplaced(region->getParentOp()) &&
2141  "attempting to apply a signature conversion to a block within a "
2142  "replaced/erased op");
2143  return impl->convertRegionTypes(region, converter, entryConversion);
2144 }
2145 
  // Replaces all uses of `from` with `to`, passing along the type converter of
  // the conversion pattern currently being applied (`currentTypeConverter`).
  // NOTE(review): the debug log below ends its line only for block arguments.
  // For op results no "\n" is printed, so the next log entry is appended to
  // the same line. The "in region of" message also opens two parentheses but
  // closes only one.
2147  LLVM_DEBUG({
2148  impl->logger.startLine() << "** Replace Value : '" << from << "'";
2149  if (auto blockArg = dyn_cast<BlockArgument>(from)) {
2150  if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
2151  impl->logger.getOStream() << " (in region of '" << parentOp->getName()
2152  << "' (" << parentOp << ")\n";
2153  } else {
2154  impl->logger.getOStream() << " (unlinked block)\n";
2155  }
2156  }
2157  });
2158  impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
2159 }
2160 
  // Returns the single value that `key` is currently remapped to, or null if
  // remapping fails. 1:N remappings are not supported (asserted).
2162  SmallVector<ValueVector> remappedValues;
2163  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2164  remappedValues)))
2165  return nullptr;
2166  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2167  return remappedValues.front().front();
2168 }
2169 
// Appends the remapped value of each key to `results`. `results` is not
// cleared first. Fails if remapping fails; 1:N remappings are not supported
// (asserted). An empty key list trivially succeeds.
2170 LogicalResult
2172  SmallVectorImpl<Value> &results) {
2173  if (keys.empty())
2174  return success();
2175  SmallVector<ValueVector> remapped;
2176  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2177  remapped)))
2178  return failure();
2179  for (const auto &values : remapped) {
2180  assert(values.size() == 1 && "1:N conversion not supported");
2181  results.push_back(values.front());
2182  }
2183  return success();
2184 }
2185 
2187  Block::iterator before,
2188  ValueRange argValues) {
  // Moves all operations of `source` into `dest` before `before`. Uses of
  // `source`'s arguments are replaced with `argValues` (one value per
  // argument). `source` is then erased. The insertion point follows the moved
  // ops if it pointed into `source`.
2189 #ifndef NDEBUG
2190  assert(argValues.size() == source->getNumArguments() &&
2191  "incorrect # of argument replacement values");
2192  assert(!impl->wasOpReplaced(source->getParentOp()) &&
2193  "attempting to inline a block from a replaced/erased op");
2194  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2195  "attempting to inline a block into a replaced/erased op");
2196  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2197  // The source block will be deleted, so it should not have any users (i.e.,
2198  // there should be no predecessors).
2199  assert(llvm::all_of(source->getUsers(), opIgnored) &&
2200  "expected 'source' to have no predecessors");
2201 #endif // NDEBUG
2202 
2203  // If a listener is attached to the dialect conversion, ops cannot be moved
2204  // to the destination block in bulk ("fast path"). This is because at the time
2205  // the notifications are sent, it is unknown which ops were moved. Instead,
2206  // ops should be moved one-by-one ("slow path"), so that a separate
2207  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2208  // a bit more efficient, so we try to do that when possible.
2209  bool fastPath = !getConfig().listener;
2210 
  // The bulk splice below sends no per-op notifications. In rollback mode the
  // whole inlining is therefore recorded up front as a single rewrite.
2211  if (fastPath && impl->config.allowPatternRollback)
2212  impl->inlineBlockBefore(source, dest, before);
2213 
2214  // Replace all uses of block arguments.
2215  for (auto it : llvm::zip(source->getArguments(), argValues))
2216  replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2217 
2218  if (fastPath) {
2219  // Move all ops at once.
2220  dest->getOperations().splice(before, source->getOperations());
2221  } else {
2222  // Move op by op.
2223  while (!source->empty())
2224  moveOpBefore(&source->front(), dest, before);
2225  }
2226 
2227  // If the current insertion point is within the source block, adjust the
2228  // insertion point to the destination block.
2229  if (getInsertionBlock() == source)
2230  setInsertionPoint(dest, getInsertionPoint());
2231 
2232  // Erase the source block.
2233  eraseBlock(source);
2234 }
2235 
  // Begins an in-place modification of `op`. In rollback mode a
  // ModifyOperationRewrite is recorded. That rewrite is what a later
  // cancellation rolls back to restore `op`. In debug builds `op` is also
  // tracked as having a pending update.
2237  if (!impl->config.allowPatternRollback) {
2238  // Pattern rollback is not allowed: no extra bookkeeping is needed.
2240  return;
2241  }
2242  assert(!impl->wasOpReplaced(op) &&
2243  "attempting to modify a replaced/erased op");
2244 #ifndef NDEBUG
2245  impl->pendingRootUpdates.insert(op);
2246 #endif
2247  impl->appendRewrite<ModifyOperationRewrite>(op);
2248 }
2249 
  // Completes an in-place modification of `op` and records it in
  // `patternModifiedOps`. In no-rollback mode the configured listener (if
  // any) is notified immediately. In rollback mode only debug-build checks
  // run here.
2251  impl->patternModifiedOps.insert(op);
2252  if (!impl->config.allowPatternRollback) {
2254  if (getConfig().listener)
2255  getConfig().listener->notifyOperationModified(op);
2256  return;
2257  }
2258 
2259  // There is nothing to do here, we only need to track the operation at the
2260  // start of the update.
2261 #ifndef NDEBUG
2262  assert(!impl->wasOpReplaced(op) &&
2263  "attempting to modify a replaced/erased op");
2264  assert(impl->pendingRootUpdates.erase(op) &&
2265  "operation did not have a pending in-place update");
2266 #endif
2267 }
2268 
  // Aborts an in-place modification of `op`. In rollback mode, the most
  // recently recorded ModifyOperationRewrite for `op` is rolled back (restoring
  // the op) and removed from the rewrite list.
2270  if (!impl->config.allowPatternRollback) {
2272  return;
2273  }
2274 #ifndef NDEBUG
2275  assert(impl->pendingRootUpdates.erase(op) &&
2276  "operation did not have a pending in-place update");
2277 #endif
2278  // Erase the last update for this operation.
2279  auto it = llvm::find_if(
2280  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2281  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2282  return modifyRewrite && modifyRewrite->getOperation() == op;
2283  });
2284  assert(it != impl->rewrites.rend() && "no root update started on op");
2285  (*it)->rollback();
  // Convert the reverse iterator into the forward index of the same element.
2286  int updateIdx = std::prev(impl->rewrites.rend()) - it;
2287  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2288 }
2289 
  // Exposes the internal implementation object used by the conversion driver.
2291  return *impl;
2292 }
2293 
2294 //===----------------------------------------------------------------------===//
2295 // ConversionPattern
2296 //===----------------------------------------------------------------------===//
2297 
2299  ArrayRef<ValueRange> operands) const {
  // Flattens `operands` to exactly one value per entry. Fails if any entry
  // holds zero or several values, i.e. it was not converted 1:1.
2300  SmallVector<Value> oneToOneOperands;
2301  oneToOneOperands.reserve(operands.size());
2302  for (ValueRange operand : operands) {
2303  if (operand.size() != 1)
2304  return failure();
2305 
2306  oneToOneOperands.push_back(operand.front());
2307  }
  // The explicit move is needed because the return type differs from the
  // local's type (presumably a FailureOr wrapper — the signature line is not
  // visible here), so implicit move-on-return does not apply.
2308  return std::move(oneToOneOperands);
2309 }
2310 
// Generic entry point for conversion patterns. It installs this pattern's type
// converter as the rewriter's current converter (restored on exit) and remaps
// the operands of `op`. It then dispatches to the adaptor-based
// `matchAndRewrite` overload. Fails without rewriting if remapping fails.
// Assumes `rewriter` is a ConversionPatternRewriter (unchecked static_cast).
2311 LogicalResult
2313  PatternRewriter &rewriter) const {
2314  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2315  auto &rewriterImpl = dialectRewriter.getImpl();
2316 
2317  // Track the current conversion pattern type converter in the rewriter.
2318  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2319  getTypeConverter());
2320 
2321  // Remap the operands of the operation.
2322  SmallVector<ValueVector> remapped;
2323  if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2324  op->getOperands(), remapped))) {
2325  return failure();
2326  }
  // `remappedAsRange` holds non-owning views into `remapped`, which stays
  // alive until after the call below.
2327  SmallVector<ValueRange> remappedAsRange =
2328  llvm::to_vector_of<ValueRange>(remapped);
2329  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2330 }
2331 
2332 //===----------------------------------------------------------------------===//
2333 // OperationLegalizer
2334 //===----------------------------------------------------------------------===//
2335 
2336 namespace {
2337 /// A set of rewrite patterns that can be used to legalize a given operation.
2338 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2339 
2340 /// This class defines a recursive operation legalizer.
/// For each op it first consults the conversion target. If the op is not
/// legal, it may try folding and otherwise applies conversion patterns. The
/// patterns are ranked up front by the cost model declared below.
2341 class OperationLegalizer {
2342 public:
2343  using LegalizationAction = ConversionTarget::LegalizationAction;
2344 
  /// Builds the legalization graph for the provided patterns and ranks them
  /// with the cost model.
2345  OperationLegalizer(ConversionPatternRewriter &rewriter,
2346  const ConversionTarget &targetInfo,
2348 
2349  /// Returns true if the given operation is known to be illegal on the target.
2350  bool isIllegal(Operation *op) const;
2351 
2352  /// Attempt to legalize the given operation. Returns success if the operation
2353  /// was legalized, failure otherwise.
2354  LogicalResult legalize(Operation *op);
2355 
2356  /// Returns the conversion target in use by the legalizer.
2357  const ConversionTarget &getTarget() { return target; }
2358 
2359 private:
2360  /// Attempt to legalize the given operation by folding it.
2361  LogicalResult legalizeWithFold(Operation *op);
2362 
2363  /// Attempt to legalize the given operation by applying a pattern. Returns
2364  /// success if the operation was legalized, failure otherwise.
2365  LogicalResult legalizeWithPattern(Operation *op);
2366 
2367  /// Return true if the given pattern may be applied to the given operation,
2368  /// false otherwise.
2369  bool canApplyPattern(Operation *op, const Pattern &pattern);
2370 
2371  /// Legalize the resultant IR after successfully applying the given pattern.
2372  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
2373  const RewriterState &curState,
2374  const SetVector<Operation *> &newOps,
2375  const SetVector<Operation *> &modifiedOps,
2376  const SetVector<Block *> &insertedBlocks);
2377 
2378  /// Legalizes the actions registered during the execution of a pattern.
2379  LogicalResult
2380  legalizePatternBlockRewrites(Operation *op,
2381  const SetVector<Block *> &insertedBlocks,
2382  const SetVector<Operation *> &newOps);
2383  LogicalResult
2384  legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
2385  LogicalResult
2386  legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2387 
2388  //===--------------------------------------------------------------------===//
2389  // Cost Model
2390  //===--------------------------------------------------------------------===//
2391 
2392  /// Build an optimistic legalization graph given the provided patterns. This
2393  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2394  /// patterns for operations that are not directly legal, but may be
2395  /// transitively legal for the current target given the provided patterns.
2396  void buildLegalizationGraph(
2397  LegalizationPatterns &anyOpLegalizerPatterns,
2399 
2400  /// Compute the benefit of each node within the computed legalization graph.
2401  /// This orders the patterns within 'legalizerPatterns' based upon two
2402  /// criteria:
2403  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2404  /// represent the more direct mapping to the target.
2405  /// 2) When comparing patterns with the same legalization depth, prefer the
2406  /// pattern with the highest PatternBenefit. This allows for users to
2407  /// prefer specific legalizations over others.
2408  void computeLegalizationGraphBenefit(
2409  LegalizationPatterns &anyOpLegalizerPatterns,
2411 
2412  /// Compute the legalization depth when legalizing an operation of the given
2413  /// type.
2414  unsigned computeOpLegalizationDepth(
2415  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2417 
2418  /// Apply the conversion cost model to the given set of patterns, and return
2419  /// the smallest legalization depth of any of the patterns. See
2420  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2421  unsigned applyCostModelToPatterns(
2422  LegalizationPatterns &patterns,
2423  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2425 
2426  /// The current set of patterns that have been applied.
2427  SmallPtrSet<const Pattern *, 8> appliedPatterns;
2428 
2429  /// The rewriter to use when converting operations.
2430  ConversionPatternRewriter &rewriter;
2431 
2432  /// The legalization information provided by the target.
2433  const ConversionTarget &target;
2434 
2435  /// The pattern applicator to use for conversions.
2436  PatternApplicator applicator;
2437 };
2438 } // namespace
2439 
2440 OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2441  const ConversionTarget &targetInfo,
2443  : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2444  // The set of patterns that can be applied to illegal operations to transform
2445  // them into legal ones.
2447  LegalizationPatterns anyOpLegalizerPatterns;
2448 
  // First determine which patterns may (transitively) produce legal IR. Then
  // rank them by legalization depth and pattern benefit. The ranking
  // presumably feeds `applicator`'s pattern order — confirm in
  // computeLegalizationGraphBenefit.
2449  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2450  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2451 }
2452 
2453 bool OperationLegalizer::isIllegal(Operation *op) const {
2454  return target.isIllegal(op);
2455 }
2456 
/// Legalizes `op` by trying, in order:
///  1. succeed immediately if the op is marked ignored;
///  2. succeed if the target reports the op legal (nested ops are marked
///     ignored when it is recursively legal);
///  3. fold before patterns (per the config's folding mode, 'BeforePatterns');
///  4. apply a legalization pattern;
///  5. fold after patterns (folding mode 'AfterPatterns').
/// Returns failure if none of these succeeds. The folding-mode checks sit on
/// lines not visible in this listing.
2457 LogicalResult OperationLegalizer::legalize(Operation *op) {
2458 #ifndef NDEBUG
2459  const char *logLineComment =
2460  "//===-------------------------------------------===//\n";
2461 
2462  auto &logger = rewriter.getImpl().logger;
2463 #endif
2464 
2465  // Check to see if the operation is ignored and doesn't need to be converted.
2466  bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2467 
2468  LLVM_DEBUG({
2469  logger.getOStream() << "\n";
2470  logger.startLine() << logLineComment;
2471  logger.startLine() << "Legalizing operation : ";
2472  // Do not print the operation name if the operation is ignored. Ignored ops
2473  // may have been erased and should not be accessed. The pointer can be
2474  // printed safely.
2475  if (!isIgnored)
2476  logger.getOStream() << "'" << op->getName() << "' ";
2477  logger.getOStream() << "(" << op << ") {\n";
2478  logger.indent();
2479 
2480  // If the operation has no regions, just print it here.
2481  if (!isIgnored && op->getNumRegions() == 0) {
2482  logger.startLine() << OpWithFlags(op,
2483  OpPrintingFlags().printGenericOpForm())
2484  << "\n";
2485  }
2486  });
2487 
2488  if (isIgnored) {
2489  LLVM_DEBUG({
2490  logSuccess(logger, "operation marked 'ignored' during conversion");
2491  logger.startLine() << logLineComment;
2492  });
2493  return success();
2494  }
2495 
2496  // Check if this operation is legal on the target.
2497  if (auto legalityInfo = target.isLegal(op)) {
2498  LLVM_DEBUG({
2499  logSuccess(
2500  logger, "operation marked legal by the target{0}",
2501  legalityInfo->isRecursivelyLegal
2502  ? "; NOTE: operation is recursively legal; skipping internals"
2503  : "");
2504  logger.startLine() << logLineComment;
2505  });
2506 
2507  // If this operation is recursively legal, mark its children as ignored so
2508  // that we don't consider them for legalization.
2509  if (legalityInfo->isRecursivelyLegal) {
2510  op->walk([&](Operation *nested) {
2511  if (op != nested)
2512  rewriter.getImpl().ignoredOps.insert(nested);
2513  });
2514  }
2515 
2516  return success();
2517  }
2518 
2519  // If the operation is not legal, try to fold it in-place if the folding mode
2520  // is 'BeforePatterns'. 'Never' will skip this.
2521  const ConversionConfig &config = rewriter.getConfig();
2523  if (succeeded(legalizeWithFold(op))) {
2524  LLVM_DEBUG({
2525  logSuccess(logger, "operation was folded");
2526  logger.startLine() << logLineComment;
2527  });
2528  return success();
2529  }
2530  }
2531 
2532  // Otherwise, we need to apply a legalization pattern to this operation.
2533  if (succeeded(legalizeWithPattern(op))) {
2534  LLVM_DEBUG({
2535  logSuccess(logger, "");
2536  logger.startLine() << logLineComment;
2537  });
2538  return success();
2539  }
2540 
2541  // If the operation can't be legalized via patterns, try to fold it in-place
2542  // if the folding mode is 'AfterPatterns'.
2544  if (succeeded(legalizeWithFold(op))) {
2545  LLVM_DEBUG({
2546  logSuccess(logger, "operation was folded");
2547  logger.startLine() << logLineComment;
2548  });
2549  return success();
2550  }
2551  }
2552 
2553  LLVM_DEBUG({
2554  logFailure(logger, "no matched legalization pattern");
2555  logger.startLine() << logLineComment;
2556  });
2557  return failure();
2558 }
2559 
/// Transfers the contents of `obj` into the returned value and leaves `obj`
/// holding a freshly default-constructed `T`, i.e. a valid, empty state that
/// can be reused immediately.
template <typename T>
static T moveAndReset(T &obj) {
  return std::exchange(obj, T());
}
2568 
2569 LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2570  auto &rewriterImpl = rewriter.getImpl();
2571  LLVM_DEBUG({
2572  rewriterImpl.logger.startLine() << "* Fold {\n";
2573  rewriterImpl.logger.indent();
2574  });
2575 
2576  // Clear pattern state, so that the next pattern application starts with a
2577  // clean slate. (The op/block sets are populated by listener notifications.)
2578  auto cleanup = llvm::make_scope_exit([&]() {
2579  rewriterImpl.patternNewOps.clear();
2580  rewriterImpl.patternModifiedOps.clear();
2581  rewriterImpl.patternInsertedBlocks.clear();
2582  });
2583 
2584  // Upon failure, undo all changes made by the folder.
2585  RewriterState curState = rewriterImpl.getCurrentState();
2586 
2587  // Try to fold the operation.
2588  StringRef opName = op->getName().getStringRef();
2589  SmallVector<Value, 2> replacementValues;
2590  SmallVector<Operation *, 2> newOps;
2591  rewriter.setInsertionPoint(op);
2592  rewriter.startOpModification(op);
2593  if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2594  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2595  rewriter.cancelOpModification(op);
2596  return failure();
2597  }
2598  rewriter.finalizeOpModification(op);
2599 
2600  // An empty list of replacement values indicates that the fold was in-place.
2601  // As the operation changed, a new legalization needs to be attempted.
2602  if (replacementValues.empty())
2603  return legalize(op);
2604 
2605  // Insert a replacement for 'op' with the folded replacement values.
2606  rewriter.replaceOp(op, replacementValues);
2607 
2608  // Recursively legalize any new constant operations.
2609  for (Operation *newOp : newOps) {
2610  if (failed(legalize(newOp))) {
2611  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2612  "failed to legalize generated constant '{0}'",
2613  newOp->getName()));
2614  if (!rewriter.getConfig().allowPatternRollback) {
2615  // Rolling back a folder is like rolling back a pattern.
2616  llvm::report_fatal_error(
2617  "op '" + opName +
2618  "' folder rollback of IR modifications requested");
2619  }
2620  rewriterImpl.resetState(
2621  curState, std::string(op->getName().getStringRef()) + " folder");
2622  return failure();
2623  }
2624  }
2625 
2626  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2627  return success();
2628 }
2629 
/// Report a fatal error indicating that newly produced or modified IR could
/// not be legalized.
///
/// The message names the offending pattern and lists, comma-separated:
/// - the names of the ops it created;
/// - the names of the ops it modified in place;
/// - the names of the parent ops of the blocks it inserted. A block without a
///   parent op is listed as "(detached block)".
///
/// The error is raised through `llvm::report_fatal_error`.
static void
    const SetVector<Operation *> &newOps,
    const SetVector<Operation *> &modifiedOps,
    const SetVector<Block *> &insertedBlocks) {
  // Map each op to its name so the sets can be passed to `llvm::join`.
  auto newOpNames = llvm::map_range(
      newOps, [](Operation *op) { return op->getName().getStringRef(); });
  auto modifiedOpNames = llvm::map_range(
      modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
  // Keep the placeholder as a StringRef so both lambda branches below return
  // the same type (the lambda's return type is deduced).
  StringRef detachedBlockStr = "(detached block)";
  auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
    if (block->getParentOp())
      return block->getParentOp()->getName().getStringRef();
    return detachedBlockStr;
  });
  llvm::report_fatal_error(
      "pattern '" + pattern.getDebugName() +
      "' produced IR that could not be legalized. " + "new ops: {" +
      llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
      llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
      llvm::join(insertedBlockNames, ", ") + "}");
}
2654 
2655 LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2656  auto &rewriterImpl = rewriter.getImpl();
2657  const ConversionConfig &config = rewriter.getConfig();
2658 
2659 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2660  Operation *checkOp;
2661  std::optional<OperationFingerPrint> topLevelFingerPrint;
2662  if (!rewriterImpl.config.allowPatternRollback) {
2663  // The op may be getting erased, so we have to check the parent op.
2664  // (In rare cases, a pattern may even erase the parent op, which will cause
2665  // a crash here. Expensive checks are "best effort".) Skip the check if the
2666  // op does not have a parent op.
2667  if ((checkOp = op->getParentOp())) {
2668  if (!op->getContext()->isMultithreadingEnabled()) {
2669  topLevelFingerPrint = OperationFingerPrint(checkOp);
2670  } else {
2671  // Another thread may be modifying a sibling operation. Therefore, the
2672  // fingerprinting mechanism of the parent op works only in
2673  // single-threaded mode.
2674  LLVM_DEBUG({
2675  rewriterImpl.logger.startLine()
2676  << "WARNING: Multi-threadeding is enabled. Some dialect "
2677  "conversion expensive checks are skipped in multithreading "
2678  "mode!\n";
2679  });
2680  }
2681  }
2682  }
2683 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2684 
2685  // Functor that returns if the given pattern may be applied.
2686  auto canApply = [&](const Pattern &pattern) {
2687  bool canApply = canApplyPattern(op, pattern);
2688  if (canApply && config.listener)
2689  config.listener->notifyPatternBegin(pattern, op);
2690  return canApply;
2691  };
2692 
2693  // Functor that cleans up the rewriter state after a pattern failed to match.
2694  RewriterState curState = rewriterImpl.getCurrentState();
2695  auto onFailure = [&](const Pattern &pattern) {
2696  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2697  if (!rewriterImpl.config.allowPatternRollback) {
2698  // Erase all unresolved materializations.
2699  for (auto op : rewriterImpl.patternMaterializations) {
2700  rewriterImpl.unresolvedMaterializations.erase(op);
2701  op.erase();
2702  }
2703  rewriterImpl.patternMaterializations.clear();
2704 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2705  // Expensive pattern check that can detect API violations.
2706  if (checkOp) {
2707  OperationFingerPrint fingerPrintAfterPattern(checkOp);
2708  if (fingerPrintAfterPattern != *topLevelFingerPrint)
2709  llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2710  "' returned failure but IR did change");
2711  }
2712 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2713  }
2714  rewriterImpl.patternNewOps.clear();
2715  rewriterImpl.patternModifiedOps.clear();
2716  rewriterImpl.patternInsertedBlocks.clear();
2717  LLVM_DEBUG({
2718  logFailure(rewriterImpl.logger, "pattern failed to match");
2719  if (rewriterImpl.config.notifyCallback) {
2721  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2722  << "\" on op:\n"
2723  << *op;
2724  rewriterImpl.config.notifyCallback(diag);
2725  }
2726  });
2727  if (config.listener)
2728  config.listener->notifyPatternEnd(pattern, failure());
2729  rewriterImpl.resetState(curState, pattern.getDebugName());
2730  appliedPatterns.erase(&pattern);
2731  };
2732 
2733  // Functor that performs additional legalization when a pattern is
2734  // successfully applied.
2735  auto onSuccess = [&](const Pattern &pattern) {
2736  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2737  if (!rewriterImpl.config.allowPatternRollback) {
2738  // Eagerly erase unused materializations.
2739  for (auto op : rewriterImpl.patternMaterializations) {
2740  if (op->use_empty()) {
2741  rewriterImpl.unresolvedMaterializations.erase(op);
2742  op.erase();
2743  }
2744  }
2745  rewriterImpl.patternMaterializations.clear();
2746  }
2747  SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2748  SetVector<Operation *> modifiedOps =
2749  moveAndReset(rewriterImpl.patternModifiedOps);
2750  SetVector<Block *> insertedBlocks =
2751  moveAndReset(rewriterImpl.patternInsertedBlocks);
2752  auto result = legalizePatternResult(op, pattern, curState, newOps,
2753  modifiedOps, insertedBlocks);
2754  appliedPatterns.erase(&pattern);
2755  if (failed(result)) {
2756  if (!rewriterImpl.config.allowPatternRollback)
2757  reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
2758  insertedBlocks);
2759  rewriterImpl.resetState(curState, pattern.getDebugName());
2760  }
2761  if (config.listener)
2762  config.listener->notifyPatternEnd(pattern, result);
2763  return result;
2764  };
2765 
2766  // Try to match and rewrite a pattern on this operation.
2767  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2768  onSuccess);
2769 }
2770 
2771 bool OperationLegalizer::canApplyPattern(Operation *op,
2772  const Pattern &pattern) {
2773  LLVM_DEBUG({
2774  auto &os = rewriter.getImpl().logger;
2775  os.getOStream() << "\n";
2776  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2777  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2778  os.getOStream() << ")' {\n";
2779  os.indent();
2780  });
2781 
2782  // Ensure that we don't cycle by not allowing the same pattern to be
2783  // applied twice in the same recursion stack if it is not known to be safe.
2784  if (!pattern.hasBoundedRewriteRecursion() &&
2785  !appliedPatterns.insert(&pattern).second) {
2786  LLVM_DEBUG(
2787  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2788  return false;
2789  }
2790  return true;
2791 }
2792 
2793 LogicalResult OperationLegalizer::legalizePatternResult(
2794  Operation *op, const Pattern &pattern, const RewriterState &curState,
2795  const SetVector<Operation *> &newOps,
2796  const SetVector<Operation *> &modifiedOps,
2797  const SetVector<Block *> &insertedBlocks) {
2798  [[maybe_unused]] auto &impl = rewriter.getImpl();
2799  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2800 
2801 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2802  // Check that the root was either replaced or updated in place.
2803  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2804  auto replacedRoot = [&] {
2805  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2806  };
2807  auto updatedRootInPlace = [&] {
2808  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2809  };
2810  if (!replacedRoot() && !updatedRootInPlace())
2811  llvm::report_fatal_error(
2812  "expected pattern to replace the root operation or modify it in place");
2813 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2814 
2815  // Legalize each of the actions registered during application.
2816  if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
2817  failed(legalizePatternRootUpdates(modifiedOps)) ||
2818  failed(legalizePatternCreatedOperations(newOps))) {
2819  return failure();
2820  }
2821 
2822  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2823  return success();
2824 }
2825 
2826 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2827  Operation *op, const SetVector<Block *> &insertedBlocks,
2828  const SetVector<Operation *> &newOps) {
2829  ConversionPatternRewriterImpl &impl = rewriter.getImpl();
2830  SmallPtrSet<Operation *, 16> alreadyLegalized;
2831 
2832  // If the pattern moved or created any blocks, make sure the types of block
2833  // arguments get legalized.
2834  for (Block *block : insertedBlocks) {
2835  if (impl.erasedBlocks.contains(block))
2836  continue;
2837 
2838  // Only check blocks outside of the current operation.
2839  Operation *parentOp = block->getParentOp();
2840  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2841  continue;
2842 
2843  // If the region of the block has a type converter, try to convert the block
2844  // directly.
2845  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2846  std::optional<TypeConverter::SignatureConversion> conversion =
2847  converter->convertBlockSignature(block);
2848  if (!conversion) {
2849  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2850  "block"));
2851  return failure();
2852  }
2853  impl.applySignatureConversion(block, converter, *conversion);
2854  continue;
2855  }
2856 
2857  // Otherwise, try to legalize the parent operation if it was not generated
2858  // by this pattern. This is because we will attempt to legalize the parent
2859  // operation, and blocks in regions created by this pattern will already be
2860  // legalized later on.
2861  if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
2862  if (failed(legalize(parentOp))) {
2863  LLVM_DEBUG(logFailure(
2864  impl.logger, "operation '{0}'({1}) became illegal after rewrite",
2865  parentOp->getName(), parentOp));
2866  return failure();
2867  }
2868  }
2869  }
2870  return success();
2871 }
2872 
2873 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2874  const SetVector<Operation *> &newOps) {
2875  for (Operation *op : newOps) {
2876  if (failed(legalize(op))) {
2877  LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2878  "failed to legalize generated operation '{0}'({1})",
2879  op->getName(), op));
2880  return failure();
2881  }
2882  }
2883  return success();
2884 }
2885 
2886 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2887  const SetVector<Operation *> &modifiedOps) {
2888  for (Operation *op : modifiedOps) {
2889  if (failed(legalize(op))) {
2890  LLVM_DEBUG(
2891  logFailure(rewriter.getImpl().logger,
2892  "failed to legalize operation updated in-place '{0}'",
2893  op->getName()));
2894  return failure();
2895  }
2896  }
2897  return success();
2898 }
2899 
2900 //===----------------------------------------------------------------------===//
2901 // Cost Model
2902 //===----------------------------------------------------------------------===//
2903 
/// Builds the per-op lists of legalization patterns that can actually lead to
/// legal IR.
///
/// - Patterns without a specific root kind are collected in
///   `anyOpLegalizerPatterns`.
/// - Patterns whose root op has the target action `Legal` are dropped.
/// - A remaining rooted pattern is appended to `legalizerPatterns[root]` once
///   every op it may generate either has a target action other than `Illegal`
///   or already has recorded legalization patterns. This is iterated to a
///   fixed point via a worklist.
///
/// If any root-less pattern exists, the fixed-point filtering is skipped and
/// every non-dropped rooted pattern is recorded.
void OperationLegalizer::buildLegalizationGraph(
    LegalizationPatterns &anyOpLegalizerPatterns,
    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
  // A mapping between an operation and a set of operations that can be used to
  // generate it.
  // A mapping between an operation and any currently invalid patterns it has.
  // A worklist of patterns to consider for legality.
  SetVector<const Pattern *> patternWorklist;

  // Build the mapping from operations to the parent ops that may generate them.
  applicator.walkAllPatterns([&](const Pattern &pattern) {
    std::optional<OperationName> root = pattern.getRootKind();

    // If the pattern has no specific root, we can't analyze the relationship
    // between the root op and generated operations. Given that, add all such
    // patterns to the legalization set.
    if (!root) {
      anyOpLegalizerPatterns.push_back(&pattern);
      return;
    }

    // Skip operations that are always known to be legal.
    if (target.getOpAction(*root) == LegalizationAction::Legal)
      return;

    // Add this pattern to the invalid set for the root op and record this root
    // as a parent for any generated operations.
    invalidPatterns[*root].insert(&pattern);
    for (auto op : pattern.getGeneratedOps())
      parentOps[op].insert(*root);

    // Add this pattern to the worklist.
    patternWorklist.insert(&pattern);
  });

  // If there are any patterns that don't have a specific root kind, we can't
  // make direct assumptions about what operations will never be legalized.
  // Note: Technically we could, but it would require an analysis that may
  // recurse into itself. It would be better to perform this kind of filtering
  // at a higher level than here anyways.
  if (!anyOpLegalizerPatterns.empty()) {
    for (const Pattern *pattern : patternWorklist)
      legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    return;
  }

  // Fixed-point iteration: each pattern that becomes valid may in turn make
  // patterns of the ops that generate its root valid.
  while (!patternWorklist.empty()) {
    auto *pattern = patternWorklist.pop_back_val();

    // Check to see if any of the generated operations are invalid. An op is
    // invalid here if it has no recorded patterns yet and the target either
    // has no action for it or marks it `Illegal`.
    if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
          std::optional<LegalizationAction> action = target.getOpAction(op);
          return !legalizerPatterns.count(op) &&
                 (!action || action == LegalizationAction::Illegal);
        }))
      continue;

    // Otherwise, if all of the generated operation are valid, this op is now
    // legal so add all of the child patterns to the worklist.
    legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    invalidPatterns[*pattern->getRootKind()].erase(pattern);

    // Add any invalid patterns of the parent operations to see if they have now
    // become legal.
    for (auto op : parentOps[*pattern->getRootKind()])
      patternWorklist.set_union(invalidPatterns[op]);
  }
}
2974 
/// Orders the legalization patterns and installs that order on the pattern
/// applicator.
///
/// Each op's pattern list (and the root-less list) is sorted by increasing
/// legalization depth and then by decreasing benefit. The applicator is then
/// given a cost model in which a pattern's benefit is its distance from the
/// end of its ordered list, so earlier entries rank higher.
void OperationLegalizer::computeLegalizationGraphBenefit(
    LegalizationPatterns &anyOpLegalizerPatterns,
    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
  // The smallest pattern depth, when legalizing an operation.
  DenseMap<OperationName, unsigned> minOpPatternDepth;

  // For each operation that is transitively legal, compute a cost for it.
  // (computeOpLegalizationDepth also sorts that op's pattern list.)
  for (auto &opIt : legalizerPatterns)
    if (!minOpPatternDepth.count(opIt.first))
      computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
                                 legalizerPatterns);

  // Apply the cost model to the patterns that can match any operation. Those
  // with a specific operation type are already resolved when computing the op
  // legalization depth.
  if (!anyOpLegalizerPatterns.empty())
    applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
                             legalizerPatterns);

  // Apply a cost model to the pattern applicator. We order patterns first by
  // depth then benefit. `legalizerPatterns` contains per-op patterns by
  // decreasing benefit.
  applicator.applyCostModel([&](const Pattern &pattern) {
    ArrayRef<const Pattern *> orderedPatternList;
    if (std::optional<OperationName> rootName = pattern.getRootKind())
      orderedPatternList = legalizerPatterns[*rootName];
    else
      orderedPatternList = anyOpLegalizerPatterns;

    // If the pattern is not found, then it was removed and cannot be matched.
    auto *it = llvm::find(orderedPatternList, &pattern);
    if (it == orderedPatternList.end())

    // Patterns found earlier in the list have higher benefit.
    return PatternBenefit(std::distance(it, orderedPatternList.end()));
  });
}
3013 
3014 unsigned OperationLegalizer::computeOpLegalizationDepth(
3015  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
3016  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
3017  // Check for existing depth.
3018  auto depthIt = minOpPatternDepth.find(op);
3019  if (depthIt != minOpPatternDepth.end())
3020  return depthIt->second;
3021 
3022  // If a mapping for this operation does not exist, then this operation
3023  // is always legal. Return 0 as the depth for a directly legal operation.
3024  auto opPatternsIt = legalizerPatterns.find(op);
3025  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3026  return 0u;
3027 
3028  // Record this initial depth in case we encounter this op again when
3029  // recursively computing the depth.
3030  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3031 
3032  // Apply the cost model to the operation patterns, and update the minimum
3033  // depth.
3034  unsigned minDepth = applyCostModelToPatterns(
3035  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3036  minOpPatternDepth[op] = minDepth;
3037  return minDepth;
3038 }
3039 
3040 unsigned OperationLegalizer::applyCostModelToPatterns(
3041  LegalizationPatterns &patterns,
3042  DenseMap<OperationName, unsigned> &minOpPatternDepth,
3043  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
3044  unsigned minDepth = std::numeric_limits<unsigned>::max();
3045 
3046  // Compute the depth for each pattern within the set.
3047  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3048  patternsByDepth.reserve(patterns.size());
3049  for (const Pattern *pattern : patterns) {
3050  unsigned depth = 1;
3051  for (auto generatedOp : pattern->getGeneratedOps()) {
3052  unsigned generatedOpDepth = computeOpLegalizationDepth(
3053  generatedOp, minOpPatternDepth, legalizerPatterns);
3054  depth = std::max(depth, generatedOpDepth + 1);
3055  }
3056  patternsByDepth.emplace_back(pattern, depth);
3057 
3058  // Update the minimum depth of the pattern list.
3059  minDepth = std::min(minDepth, depth);
3060  }
3061 
3062  // If the operation only has one legalization pattern, there is no need to
3063  // sort them.
3064  if (patternsByDepth.size() == 1)
3065  return minDepth;
3066 
3067  // Sort the patterns by those likely to be the most beneficial.
3068  llvm::stable_sort(patternsByDepth,
3069  [](const std::pair<const Pattern *, unsigned> &lhs,
3070  const std::pair<const Pattern *, unsigned> &rhs) {
3071  // First sort by the smaller pattern legalization
3072  // depth.
3073  if (lhs.second != rhs.second)
3074  return lhs.second < rhs.second;
3075 
3076  // Then sort by the larger pattern benefit.
3077  auto lhsBenefit = lhs.first->getBenefit();
3078  auto rhsBenefit = rhs.first->getBenefit();
3079  return lhsBenefit > rhsBenefit;
3080  });
3081 
3082  // Update the legalization pattern to use the new sorted list.
3083  patterns.clear();
3084  for (auto &patternIt : patternsByDepth)
3085  patterns.push_back(patternIt.first);
3086  return minDepth;
3087 }
3088 
3089 //===----------------------------------------------------------------------===//
3090 // Reconcile Unrealized Casts
3091 //===----------------------------------------------------------------------===//
3092 
/// Try to reconcile all given UnrealizedConversionCastOps and store the
/// left-over ops in `remainingCastOps` (if provided). See documentation in
/// DialectConversion.h for more details.
/// The `isCastOpOfInterestFn` is used to filter the cast ops to process: the
/// algorithm may visit an operand (or user) which is a cast op, but will not
/// try to reconcile it if not in the filtered set.
///
/// The algorithm works in two phases:
///   1. For each cast, walk up its chain of input casts. If a cast is found
///      whose input types equal this cast's result types, forward all uses of
///      this cast to that cast's inputs. A chain that leads back to the
///      starting cast is treated as a cycle and left alone.
///   2. A cast is live if it is (transitively) used by an op that is not a cast
///      of interest. Live casts are kept and reported in `remainingCastOps`.
///      Every other cast in `castOps` is erased.
///
/// Note: `castOps` is iterated more than once.
template <typename RangeT>
    RangeT castOps,
    function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
  // A worklist of cast ops to process.
  SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);

  // Helper function that returns the unrealized_conversion_cast op that
  // defines all inputs of the given op (in the same order). Return "nullptr"
  // if there is no such op.
  auto getInputCast =
      [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
    if (castOp.getInputs().empty())
      return {};
    auto inputCastOp =
        castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
    if (!inputCastOp)
      return {};
    if (inputCastOp.getOutputs() != castOp.getInputs())
      return {};
    return inputCastOp;
  };

  // Process ops in the worklist bottom-to-top.
  while (!worklist.empty()) {
    UnrealizedConversionCastOp castOp = worklist.pop_back_val();

    // Traverse the chain of input cast ops to see if an op with the same
    // input types can be found.
    UnrealizedConversionCastOp nextCast = castOp;
    while (nextCast) {
      if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
        if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
              return v.getDefiningOp() == castOp;
            })) {
          // Ran into a cycle.
          break;
        }

        // Found a cast where the input types match the output types of the
        // matched op. We can directly use those inputs.
        castOp.replaceAllUsesWith(nextCast.getInputs());
        break;
      }
      nextCast = getInputCast(nextCast);
    }
  }

  // A set of all alive cast ops. I.e., ops whose results are (transitively)
  // used by an op that is not a cast op.
  DenseSet<Operation *> liveOps;

  // Helper function that marks the given op and transitively reachable input
  // cast ops as alive. Only casts accepted by `isCastOpOfInterestFn` are
  // followed.
  auto markOpLive = [&](Operation *rootOp) {
    SmallVector<Operation *> worklist;
    worklist.push_back(rootOp);
    while (!worklist.empty()) {
      Operation *op = worklist.pop_back_val();
      if (liveOps.insert(op).second) {
        // Successfully inserted: process reachable input cast ops.
        for (Value v : op->getOperands())
          if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
            if (isCastOpOfInterestFn(castOp))
              worklist.push_back(castOp);
      }
    }
  };

  // Find all alive cast ops.
  for (UnrealizedConversionCastOp op : castOps) {
    // The op may have been marked live already as being an operand of another
    // live cast op.
    if (liveOps.contains(op.getOperation()))
      continue;
    // If any of the users is not a cast op, mark the current op (and its
    // input ops) as live.
    if (llvm::any_of(op->getUsers(), [&](Operation *user) {
          auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
          return !castOp || !isCastOpOfInterestFn(castOp);
        }))
      markOpLive(op);
  }

  // Erase all dead cast ops.
  for (UnrealizedConversionCastOp op : castOps) {
    if (liveOps.contains(op)) {
      // Op is alive and was not erased. Add it to the remaining cast ops.
      if (remainingCastOps)
        remainingCastOps->push_back(op);
      continue;
    }

    // Op is dead. Erase it. Its uses (if any) come only from other dead casts,
    // so they are dropped first.
    op->dropAllUses();
    op->erase();
  }
}
3198 
    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
  // Set of all cast ops for faster lookups.
  for (UnrealizedConversionCastOp op : castOps)
    castOpSet.insert(op);
  // Delegate to the overload that takes the collected set of casts.
  reconcileUnrealizedCasts(castOpSet, remainingCastOps);
}
3208 
    const DenseSet<UnrealizedConversionCastOp> &castOps,
    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
      llvm::make_range(castOps.begin(), castOps.end()),
      [&](UnrealizedConversionCastOp castOp) {
        // Only casts that are members of the input set are reconciled.
        return castOps.contains(castOp);
      },
      remainingCastOps);
}
3219 
// Overload taking a map keyed by cast op: the keys are the casts to reconcile.
// The mapped values are not read here.
namespace mlir {
    &castOps,
    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
      castOps.keys(),
      [&](UnrealizedConversionCastOp castOp) {
        // Only casts that are keys of `castOps` are reconciled.
        return castOps.contains(castOp);
      },
      remainingCastOps);
}
} // namespace mlir
3233 
3234 //===----------------------------------------------------------------------===//
3235 // OperationConverter
3236 //===----------------------------------------------------------------------===//
3237 
namespace {
/// Selects how the operation converter treats operations that fail to
/// legalize.
enum OpConversionMode {
  /// Failed conversions are tolerated (unless the op was explicitly marked
  /// illegal), so illegal operations may co-exist with legal ones in the IR.
  Partial = 0,

  /// Every operation must end up legal for the given target; any failure
  /// makes the conversion fail.
  Full = 1,

  /// Operations are only analyzed for legality. No actual rewrites are applied
  /// to the operations on success.
  Analysis = 2,
};
} // namespace
3253 
namespace mlir {
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
  // `rewriter` is declared before `opLegalizer`, so it is fully constructed by
  // the time it is handed to the legalizer in the initializer list below.
  explicit OperationConverter(MLIRContext *ctx, const ConversionTarget &target,
                              const ConversionConfig &config,
                              OpConversionMode mode)
      : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
        mode(mode) {}

  /// Converts the given operations to the conversion target.
  LogicalResult convertOperations(ArrayRef<Operation *> ops);

private:
  /// Converts an operation with the given rewriter. Whether a legalization
  /// failure is reported as an error depends on `mode`.
  LogicalResult convert(Operation *op);

  /// The rewriter to use when converting operations.
  ConversionPatternRewriter rewriter;

  /// The legalizer to use when converting operations.
  OperationLegalizer opLegalizer;

  /// The conversion mode to use when legalizing operations.
  OpConversionMode mode;
};
} // namespace mlir
3283 
3284 LogicalResult OperationConverter::convert(Operation *op) {
3285  const ConversionConfig &config = rewriter.getConfig();
3286 
3287  // Legalize the given operation.
3288  if (failed(opLegalizer.legalize(op))) {
3289  // Handle the case of a failed conversion for each of the different modes.
3290  // Full conversions expect all operations to be converted.
3291  if (mode == OpConversionMode::Full)
3292  return op->emitError()
3293  << "failed to legalize operation '" << op->getName() << "'";
3294  // Partial conversions allow conversions to fail iff the operation was not
3295  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
3296  // set, non-legalizable ops are added to that set.
3297  if (mode == OpConversionMode::Partial) {
3298  if (opLegalizer.isIllegal(op))
3299  return op->emitError()
3300  << "failed to legalize operation '" << op->getName()
3301  << "' that was explicitly marked illegal";
3302  if (config.unlegalizedOps)
3303  config.unlegalizedOps->insert(op);
3304  }
3305  } else if (mode == OpConversionMode::Analysis) {
3306  // Analysis conversions don't fail if any operations fail to legalize,
3307  // they are only interested in the operations that were successfully
3308  // legalized.
3309  if (config.legalizableOps)
3310  config.legalizableOps->insert(op);
3311  }
3312  return success();
3313 }
3314 
/// Tries to replace the still-live unresolved materialization `op` with a real
/// materialization built by the type converter recorded in `info`.
///
/// - Target materializations may produce several values.
/// - Source materializations must have exactly one result.
///
/// If no converter is recorded, or the converter produces nothing, an error is
/// emitted on `op` (with a note at its first user) and failure is returned.
static LogicalResult
    UnrealizedConversionCastOp op,
    const UnresolvedMaterializationInfo &info) {
  assert(!op.use_empty() &&
         "expected that dead materializations have already been DCE'd");
  Operation::operand_range inputOperands = op.getOperands();

  // Try to materialize the conversion.
  if (const TypeConverter *converter = info.getConverter()) {
    rewriter.setInsertionPoint(op);
    SmallVector<Value> newMaterialization;
    switch (info.getMaterializationKind()) {
    case MaterializationKind::Target:
      newMaterialization = converter->materializeTargetConversion(
          rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
          info.getOriginalType());
      break;
    case MaterializationKind::Source:
      assert(op->getNumResults() == 1 && "expected single result");
      Value sourceMat = converter->materializeSourceConversion(
          rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
      if (sourceMat)
        newMaterialization.push_back(sourceMat);
      break;
    }
    // An empty result means the converter could not materialize; fall through
    // to the error below.
    if (!newMaterialization.empty()) {
#ifndef NDEBUG
      ValueRange newMaterializationRange(newMaterialization);
      assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
             "materialization callback produced value of incorrect type");
#endif // NDEBUG
      rewriter.replaceOp(op, newMaterialization);
      return success();
    }
  }

  // The assert above guarantees at least one user, so dereferencing
  // `getUsers().begin()` is valid.
  InFlightDiagnostic diag = op->emitError()
                            << "failed to legalize unresolved materialization "
                               "from ("
                            << inputOperands.getTypes() << ") to ("
                            << op.getResultTypes()
                            << ") that remained live after conversion";
  diag.attachNote(op->getUsers().begin()->getLoc())
      << "see existing live user here: " << *op->getUsers().begin();
  return failure();
}
3362 
/// OperationConverter::convertOperations(ops) -- NOTE(review): the signature
/// line of this function is not visible in this listing.
/// Converts the ops in `ops` together with their nested ops, then either
/// commits the recorded rewrites or (on failure) undoes/commits them depending
/// on `allowPatternRollback`, and finally tries to legalize leftover casts.
3364  const ConversionTarget &target = opLegalizer.getTarget();
3365 
3366  // Compute the set of operations and blocks to convert.
3367  SmallVector<Operation *> toConvert;
3368  for (auto *op : ops) {
// Walk each root op (the walk call itself is not visible in this listing) and
// collect every visited op; the children of recursively-legal ops are skipped.
3370  [&](Operation *op) {
3371  toConvert.push_back(op);
3372  // Don't check this operation's children for conversion if the
3373  // operation is recursively legal.
3374  auto legalityInfo = target.isLegal(op);
3375  if (legalityInfo && legalityInfo->isRecursivelyLegal)
3376  return WalkResult::skip();
3377  return WalkResult::advance();
3378  });
3379  }
3380 
3381  // Convert each operation and discard rewrites on failure.
3382  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3383 
// Stop at the first op that fails to convert (what counts as a failure depends
// on the conversion mode; see `convert`).
3384  for (auto *op : toConvert) {
3385  if (failed(convert(op))) {
3386  // Dialect conversion failed.
3387  if (rewriterImpl.config.allowPatternRollback) {
3388  // Rollback is allowed: restore the original IR.
3389  rewriterImpl.undoRewrites();
3390  } else {
3391  // Rollback is not allowed: apply all modifications that have been
3392  // performed so far.
3393  rewriterImpl.applyRewrites();
3394  }
3395  return failure();
3396  }
3397  }
3398 
3399  // After a successful conversion, apply rewrites.
3400  rewriterImpl.applyRewrites();
3401 
3402  // Reconcile all UnrealizedConversionCastOps that were inserted by the
3403  // dialect conversion frameworks. (Not the ones that were inserted by
3404  // patterns.)
3406  &materializations = rewriterImpl.unresolvedMaterializations;
3407  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
3408  reconcileUnrealizedCasts(materializations, &remainingCastOps);
3409 
3410  // Drop markers.
3411  for (UnrealizedConversionCastOp castOp : remainingCastOps)
3412  castOp->removeAttr(kPureTypeConversionMarker);
3413 
3414  // Try to legalize all unresolved materializations.
// NOTE(review): the rewrites were already applied above, so a failure in this
// phase returns failure() without restoring the original IR.
3415  if (rewriter.getConfig().buildMaterializations) {
3416  // Use a new rewriter, so the modifications are not tracked for rollback
3417  // purposes etc.
3418  IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3419  rewriter.getConfig().listener);
3420  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3421  auto it = materializations.find(castOp);
3422  assert(it != materializations.end() && "inconsistent state");
3423  if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3424  it->second)))
3425  return failure();
3426  }
3427  }
3428 
3429  return success();
3430 }
3431 
3432 //===----------------------------------------------------------------------===//
3433 // Type Conversion
3434 //===----------------------------------------------------------------------===//
3435 
3437  ArrayRef<Type> types) {
  // Remap original input `origInputNo` onto `types.size()` new inputs that
  // start at the current end of `argTypes`, then append those types.
3438  assert(!types.empty() && "expected valid types");
3439  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3440  addInputs(types);
3441 }
3442 
  // Append new input types that do not correspond to any original input.
  // An empty list is rejected: dropping an input (1->0) needs no explicit call.
3444  assert(!types.empty() &&
3445  "1->0 type remappings don't need to be added explicitly");
3446  argTypes.append(types.begin(), types.end());
3447 }
3448 
3449 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3450  unsigned newInputNo,
3451  unsigned newInputCount) {
3452  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3453  assert(newInputCount != 0 && "expected valid input count");
3454  remappedInputs[origInputNo] =
3455  InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3456 }
3457 
3459  unsigned origInputNo, ArrayRef<Value> replacements) {
  // Map original input `origInputNo` directly onto existing SSA values instead
  // of new block arguments: the mapping has size 0 and carries a copy of
  // `replacements`.
3460  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3461  remappedInputs[origInputNo] = InputMapping{
3462  origInputNo, /*size=*/0,
3463  SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3464 }
3465 
3466 /// Internal implementation of the type conversion.
3467 /// This is used with either a Type or a Value as the first argument.
3468 /// - we can cache the context-free conversions until the last registered
3469 /// context-aware conversion.
3470 /// - we can't cache the result of type conversion happening after context-aware
3471 /// conversions, because the type converter may return different results for the
3472 /// same input type.
3473 LogicalResult
3474 TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3475  SmallVectorImpl<Type> &results) const {
3476  assert(typeOrValue && "expected non-null type");
3477  Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3478  : cast<Type>(typeOrValue);
3479  {
3480  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3481  std::defer_lock);
3483  cacheReadLock.lock();
3484  auto existingIt = cachedDirectConversions.find(t);
3485  if (existingIt != cachedDirectConversions.end()) {
3486  if (existingIt->second)
3487  results.push_back(existingIt->second);
3488  return success(existingIt->second != nullptr);
3489  }
3490  auto multiIt = cachedMultiConversions.find(t);
3491  if (multiIt != cachedMultiConversions.end()) {
3492  results.append(multiIt->second.begin(), multiIt->second.end());
3493  return success();
3494  }
3495  }
3496  // Walk the added converters in reverse order to apply the most recently
3497  // registered first.
3498  size_t currentCount = results.size();
3499 
3500  // We can cache the context-free conversions until the last registered
3501  // context-aware conversion. But only if we're processing a Value right now.
3502  auto isCacheable = [&](int index) {
3503  int numberOfConversionsUntilContextAware =
3504  conversions.size() - 1 - contextAwareTypeConversionsIndex;
3505  return index < numberOfConversionsUntilContextAware;
3506  };
3507 
3508  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3509  std::defer_lock);
3510 
3511  for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3512  const ConversionCallbackFn &converter = indexedConverter.value();
3513  std::optional<LogicalResult> result = converter(typeOrValue, results);
3514  if (!result) {
3515  assert(results.size() == currentCount &&
3516  "failed type conversion should not change results");
3517  continue;
3518  }
3519  if (!isCacheable(indexedConverter.index()))
3520  return success();
3522  cacheWriteLock.lock();
3523  if (!succeeded(*result)) {
3524  assert(results.size() == currentCount &&
3525  "failed type conversion should not change results");
3526  cachedDirectConversions.try_emplace(t, nullptr);
3527  return failure();
3528  }
3529  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3530  if (newTypes.size() == 1)
3531  cachedDirectConversions.try_emplace(t, newTypes.front());
3532  else
3533  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3534  return success();
3535  }
3536  return failure();
3537 }
3538 
3540  SmallVectorImpl<Type> &results) const {
  // Type overload: forwards to convertTypeImpl with a Type, so context-aware
  // conversions see no Value.
3541  return convertTypeImpl(t, results);
3542 }
3543 
3545  SmallVectorImpl<Type> &results) const {
  // Value overload: forwards the Value itself to convertTypeImpl.
3546  return convertTypeImpl(v, results);
3547 }
3548 
  // Single-result convenience wrapper: returns a null Type when the conversion
  // fails or when it produces anything other than exactly one type.
3550  // Use the multi-type result version to convert the type.
3551  SmallVector<Type, 1> results;
3552  if (failed(convertType(t, results)))
3553  return nullptr;
3554 
3555  // Check to ensure that only one type was produced.
3556  return results.size() == 1 ? results.front() : nullptr;
3557 }
3558 
  // Same as above, for a Value.
3560  // Use the multi-type result version to convert the type.
3561  SmallVector<Type, 1> results;
3562  if (failed(convertType(v, results)))
3563  return nullptr;
3564 
3565  // Check to ensure that only one type was produced.
3566  return results.size() == 1 ? results.front() : nullptr;
3567 }
3568 
  // Converts each type in order, appending to `results`. NOTE(review): on
  // failure, results of the types converted so far remain in `results`.
3569 LogicalResult
3571  SmallVectorImpl<Type> &results) const {
3572  for (Type type : types)
3573  if (failed(convertType(type, results)))
3574  return failure();
3575  return success();
3576 }
3577 
  // Value-range counterpart of the loop above; same partial-results caveat.
3578 LogicalResult
3580  SmallVectorImpl<Type> &results) const {
3581  for (Value value : values)
3582  if (failed(convertType(value, results)))
3583  return failure();
3584  return success();
3585 }
3586 
3587 bool TypeConverter::isLegal(Type type) const {
3588  return convertType(type) == type;
3589 }
3590 
3591 bool TypeConverter::isLegal(Value value) const {
3592  return convertType(value) == value.getType();
3593 }
3594 
  // An operation is legal when all of its operands and all of its results are
  // legal; region block arguments are not inspected here.
3596  return isLegal(op->getOperands()) && isLegal(op->getResults());
3597 }
3598 
3599 bool TypeConverter::isLegal(Region *region) const {
3600  return llvm::all_of(
3601  *region, [this](Block &block) { return isLegal(block.getArguments()); });
3602 }
3603 
3604 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3605  if (!isLegal(ty.getInputs()))
3606  return false;
3607  if (!isLegal(ty.getResults()))
3608  return false;
3609  return true;
3610 }
3611 
// Converts the type of original input `inputNo` and records the resulting new
// inputs in `result`. If the type converts to zero types, nothing is recorded
// for that input.
3612 LogicalResult
3614  SignatureConversion &result) const {
3615  // Try to convert the given input type.
3616  SmallVector<Type, 1> convertedTypes;
3617  if (failed(convertType(type, convertedTypes)))
3618  return failure();
3619 
3620  // If this argument is being dropped, there is nothing left to do.
3621  if (convertedTypes.empty())
3622  return success();
3623 
3624  // Otherwise, add the new inputs.
3625  result.addInputs(inputNo, convertedTypes);
3626  return success();
3627 }
// Converts each type in `types` as original input `origInputOffset + i`,
// stopping at the first failure.
3628 LogicalResult
3630  SignatureConversion &result,
3631  unsigned origInputOffset) const {
3632  for (unsigned i = 0, e = types.size(); i != e; ++i)
3633  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3634  return failure();
3635  return success();
3636 }
// Value overload of convertSignatureArg: converts via the Value (so
// context-aware conversions can participate).
3637 LogicalResult
3639  SignatureConversion &result) const {
3640  // Try to convert the given input type.
3641  SmallVector<Type, 1> convertedTypes;
3642  if (failed(convertType(value, convertedTypes)))
3643  return failure();
3644 
3645  // If this argument is being dropped, there is nothing left to do.
3646  if (convertedTypes.empty())
3647  return success();
3648 
3649  // Otherwise, add the new inputs.
3650  result.addInputs(inputNo, convertedTypes);
3651  return success();
3652 }
// Value-range counterpart of convertSignatureArgs; stops at the first failure.
3653 LogicalResult
3655  SignatureConversion &result,
3656  unsigned origInputOffset) const {
3657  for (unsigned i = 0, e = values.size(); i != e; ++i)
3658  if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3659  return failure();
3660  return success();
3661 }
3662 
3664  Location loc, Type resultType,
3665  ValueRange inputs) const {
  // Try the registered source materializations, most recently added first;
  // return the first non-null value, or null if none applies.
3666  for (const SourceMaterializationCallbackFn &fn :
3667  llvm::reverse(sourceMaterializations))
3668  if (Value result = fn(builder, resultType, inputs, loc))
3669  return result;
3670  return nullptr;
3671 }
3672 
3674  Location loc, Type resultType,
3675  ValueRange inputs,
3676  Type originalType) const {
  // Single-result wrapper: the call on the next line is not fully visible in
  // this listing; presumably it is the multi-result materializeTargetConversion
  // overload invoked with a one-element TypeRange. Returns null when no
  // materialization applied.
3678  builder, loc, TypeRange(resultType), inputs, originalType);
3679  if (result.empty())
3680  return nullptr;
3681  assert(result.size() == 1 && "expected single result");
3682  return result.front();
3683 }
3684 
3686  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3687  Type originalType) const {
  // Try the registered target materializations, most recently added first.
  // An empty result means "not applicable"; a non-empty one must match
  // `resultTypes` exactly. Returns an empty vector if none applies.
3688  for (const TargetMaterializationCallbackFn &fn :
3689  llvm::reverse(targetMaterializations)) {
3690  SmallVector<Value> result =
3691  fn(builder, resultTypes, inputs, loc, originalType);
3692  if (result.empty())
3693  continue;
3694  assert(TypeRange(ValueRange(result)) == resultTypes &&
3695  "callback produced incorrect number of values or values with "
3696  "incorrect types");
3697  return result;
3698  }
3699  return {};
3700 }
3701 
// Computes a signature conversion for all arguments of `block`; returns
// std::nullopt if any argument fails to convert.
3702 std::optional<TypeConverter::SignatureConversion>
3704  SignatureConversion conversion(block->getNumArguments());
3705  if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3706  return std::nullopt;
3707  return conversion;
3708 }
3709 
3710 //===----------------------------------------------------------------------===//
3711 // Type attribute conversion
3712 //===----------------------------------------------------------------------===//
  // Factory: a result carrying the converted attribute.
3715  return AttributeConversionResult(attr, resultTag);
3716 }
3717 
  // Factory: "not applicable" -- lets the next registered conversion try.
3720  return AttributeConversionResult(nullptr, naTag);
3721 }
3722 
  // Factory: "abort" -- stops the search without a result.
3725  return AttributeConversionResult(nullptr, abortTag);
3726 }
3727 
  // The tag stored in the int bits of `impl` identifies which factory built
  // this result.
3729  return impl.getInt() == resultTag;
3730 }
3731 
3733  return impl.getInt() == naTag;
3734 }
3735 
3737  return impl.getInt() == abortTag;
3738 }
3739 
  // Only valid on a result built with a converted attribute.
3741  assert(hasResult() && "Cannot get result from N/A or abort");
3742  return impl.getPointer();
3743 }
3744 
// Tries the registered type-attribute conversions, most recently added first.
// Returns the first produced attribute; an "abort" or exhausting all callbacks
// yields std::nullopt, while "not applicable" moves on to the next callback.
3745 std::optional<Attribute>
3747  for (const TypeAttributeConversionCallbackFn &fn :
3748  llvm::reverse(typeAttributeConversions)) {
3749  AttributeConversionResult res = fn(type, attr);
3750  if (res.hasResult())
3751  return res.getResult();
3752  if (res.isAbort())
3753  return std::nullopt;
3754  }
3755  return std::nullopt;
3756 }
3757 
3758 //===----------------------------------------------------------------------===//
3759 // FunctionOpInterfaceSignatureConversion
3760 //===----------------------------------------------------------------------===//
3761 
3762 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3763  const TypeConverter &typeConverter,
3764  ConversionPatternRewriter &rewriter) {
3765  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3766  if (!type)
3767  return failure();
3768 
3769  // Convert the original function types.
3770  TypeConverter::SignatureConversion result(type.getNumInputs());
3771  SmallVector<Type, 1> newResults;
3772  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3773  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3774  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3775  typeConverter, &result)))
3776  return failure();
3777 
3778  // Update the function signature in-place.
3779  auto newType = FunctionType::get(rewriter.getContext(),
3780  result.getConvertedTypes(), newResults);
3781 
3782  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3783 
3784  return success();
3785 }
3786 
3787 /// Create a default conversion pattern that rewrites the type signature of a
3788 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3789 /// represent their type.
3790 namespace {
/// Pattern bound to one named function-like op; the matched op is cast to
/// FunctionOpInterface and rewritten via convertFuncOpTypes.
3791 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3792  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3793  MLIRContext *ctx,
3794  const TypeConverter &converter,
3795  PatternBenefit benefit)
3796  : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3797 
3798  LogicalResult
3799  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3800  ConversionPatternRewriter &rewriter) const override {
3801  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3802  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3803  }
3804 };
3805 
/// Interface-based variant: matches any op implementing FunctionOpInterface.
/// NOTE(review): the constructor-inheriting line of this struct is not
/// visible in this listing.
3806 struct AnyFunctionOpInterfaceSignatureConversion
3807  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3809 
3810  LogicalResult
3811  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3812  ConversionPatternRewriter &rewriter) const override {
3813  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3814  }
3815 };
3816 } // namespace
3817 
// Creates a copy of `op` with the given `operands` and converted result types.
// Fails with a match-failure note if `op` is already legal for `converter` or
// if its result types cannot be converted.
// NOTE(review): only operands, result types and attributes are transferred;
// regions and successors of `op` are not copied here. The original op is not
// erased or replaced by this function.
3818 FailureOr<Operation *>
3820  const TypeConverter &converter,
3821  ConversionPatternRewriter &rewriter) {
3822  assert(op && "Invalid op");
3823  Location loc = op->getLoc();
3824  if (converter.isLegal(op))
3825  return rewriter.notifyMatchFailure(loc, "op already legal");
3826 
3827  OperationState newOp(loc, op->getName());
3828  newOp.addOperands(operands);
3829 
3830  SmallVector<Type> newResultTypes;
3831  if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3832  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3833 
3834  newOp.addTypes(newResultTypes);
3835  newOp.addAttributes(op->getAttrs());
3836  return rewriter.create(newOp);
3837 }
3838 
3840  StringRef functionLikeOpName, RewritePatternSet &patterns,
3841  const TypeConverter &converter, PatternBenefit benefit) {
  // Registers the signature-conversion pattern for ops named
  // `functionLikeOpName` only.
3842  patterns.add<FunctionOpInterfaceSignatureConversion>(
3843  functionLikeOpName, patterns.getContext(), converter, benefit);
3844 }
3845 
3847  RewritePatternSet &patterns, const TypeConverter &converter,
3848  PatternBenefit benefit) {
  // Registers the signature-conversion pattern for every op implementing
  // FunctionOpInterface.
3849  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3850  converter, patterns.getContext(), benefit);
3851 }
3852 
3853 //===----------------------------------------------------------------------===//
3854 // ConversionTarget
3855 //===----------------------------------------------------------------------===//
3856 
3858  LegalizationAction action) {
  // Overwrites only the action; any other stored info for the op is kept.
3859  legalOperations[op].action = action;
3860 }
3861 
3863  LegalizationAction action) {
  // Assign the same action to every listed dialect namespace.
3864  for (StringRef dialect : dialectNames)
3865  legalDialects[dialect] = action;
3866 }
3867 
3869  -> std::optional<LegalizationAction> {
  // Returns the action that applies to `op` (op-, dialect- or unknown-level,
  // as resolved by getOpInfo), or std::nullopt if none is registered.
3870  std::optional<LegalizationInfo> info = getOpInfo(op);
3871  return info ? info->action : std::optional<LegalizationAction>();
3872 }
3873 
3875  -> std::optional<LegalOpDetails> {
  // Returns legality details when `op` is legal for this target, or
  // std::nullopt when it is illegal or has no registered legality info.
3876  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3877  if (!info)
3878  return std::nullopt;
3879 
3880  // Returns true if this operation instance is known to be legal.
3881  auto isOpLegal = [&] {
3882  // Handle dynamic legality with the provided legality function.
  // NOTE(review): this assumes a callback is present for Dynamic actions;
  // getOpInfo can return an empty callback for a dialect marked Dynamic
  // without a registered dialect callback.
3883  if (info->action == LegalizationAction::Dynamic) {
3884  std::optional<bool> result = info->legalityFn(op);
3885  if (result)
3886  return *result;
3887  }
3888 
3889  // Otherwise, the operation is only legal if it was marked 'Legal'.
3890  return info->action == LegalizationAction::Legal;
3891  };
3892  if (!isOpLegal())
3893  return std::nullopt;
3894 
3895  // This operation is legal, compute any additional legality information.
  // A recursive-legality callback with no opinion (std::nullopt) is treated as
  // "recursively legal".
3896  LegalOpDetails legalityDetails;
3897  if (info->isRecursivelyLegal) {
3898  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3899  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3900  legalityDetails.isRecursivelyLegal =
3901  legalityFnIt->second(op).value_or(true);
3902  } else {
3903  legalityDetails.isRecursivelyLegal = true;
3904  }
3905  }
3906  return legalityDetails;
3907 }
3908 
  // Returns true only when `op` is known to be illegal: it is marked Illegal,
  // or it is Dynamic and its callback explicitly returns false. Ops with no
  // legality info, or whose callback has no opinion, are not "illegal".
3910  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3911  if (!info)
3912  return false;
3913 
3914  if (info->action == LegalizationAction::Dynamic) {
3915  std::optional<bool> result = info->legalityFn(op);
3916  if (!result)
3917  return false;
3918 
3919  return !(*result);
3920  }
3921 
3922  return info->action == LegalizationAction::Illegal;
3923 }
3924 
3928  if (!oldCallback)
3929  return newCallback;
3930 
3931  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3932  Operation *op) -> std::optional<bool> {
3933  if (std::optional<bool> result = newCl(op))
3934  return *result;
3935 
3936  return oldCl(op);
3937  };
3938  return chain;
3939 }
3940 
3941 void ConversionTarget::setLegalityCallback(
3942  OperationName name, const DynamicLegalityCallbackFn &callback) {
3943  assert(callback && "expected valid legality callback");
3944  auto *infoIt = legalOperations.find(name);
3945  assert(infoIt != legalOperations.end() &&
3946  infoIt->second.action == LegalizationAction::Dynamic &&
3947  "expected operation to already be marked as dynamically legal");
3948  infoIt->second.legalityFn =
3949  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3950 }
3951 
3953  OperationName name, const DynamicLegalityCallbackFn &callback) {
  // Marks an already-registered, non-Illegal op as recursively legal. A
  // non-null `callback` is chained (newest first) with earlier recursive
  // callbacks; a null one erases any stored callback, making the op
  // unconditionally recursively legal.
3954  auto *infoIt = legalOperations.find(name);
3955  assert(infoIt != legalOperations.end() &&
3956  infoIt->second.action != LegalizationAction::Illegal &&
3957  "expected operation to already be marked as legal");
3958  infoIt->second.isRecursivelyLegal = true;
3959  if (callback)
3960  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3961  std::move(opRecursiveLegalityFns[name]), callback);
3962  else
3963  opRecursiveLegalityFns.erase(name);
3964 }
3965 
3966 void ConversionTarget::setLegalityCallback(
3967  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3968  assert(callback && "expected valid legality callback");
3969  for (StringRef dialect : dialects)
3970  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3971  std::move(dialectLegalityFns[dialect]), callback);
3972 }
3973 
3974 void ConversionTarget::setLegalityCallback(
3975  const DynamicLegalityCallbackFn &callback) {
3976  assert(callback && "expected valid legality callback");
3977  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3978 }
3979 
3980 auto ConversionTarget::getOpInfo(OperationName op) const
3981  -> std::optional<LegalizationInfo> {
3982  // Check for info for this specific operation.
3983  const auto *it = legalOperations.find(op);
3984  if (it != legalOperations.end())
3985  return it->second;
3986  // Check for info for the parent dialect.
3987  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3988  if (dialectIt != legalDialects.end()) {
3989  DynamicLegalityCallbackFn callback;
3990  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3991  if (dialectFn != dialectLegalityFns.end())
3992  callback = dialectFn->second;
3993  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3994  callback};
3995  }
3996  // Otherwise, check if we mark unknown operations as dynamic.
3997  if (unknownLegalityFn)
3998  return LegalizationInfo{LegalizationAction::Dynamic,
3999  /*isRecursivelyLegal=*/false, unknownLegalityFn};
4000  return std::nullopt;
4001 }
4002 
4003 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4004 //===----------------------------------------------------------------------===//
4005 // PDL Configuration
4006 //===----------------------------------------------------------------------===//
4007 
  // PDL hook run before a rewrite: make this pattern's type converter the
  // rewriter's current one.
4009  auto &rewriterImpl =
4010  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4011  rewriterImpl.currentTypeConverter = getTypeConverter();
4012 }
4013 
  // PDL hook run after a rewrite: clear the current type converter.
4015  auto &rewriterImpl =
4016  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4017  rewriterImpl.currentTypeConverter = nullptr;
4018 }
4019 
4020 /// Remap the given value using the rewriter and the type converter in the
4021 /// provided config.
/// (More precisely: the values are remapped via
/// `rewriter.getRemappedValues`; failure is propagated if remapping fails.)
4022 static FailureOr<SmallVector<Value>>
4024  SmallVector<Value> mappedValues;
4025  if (failed(rewriter.getRemappedValues(values, mappedValues)))
4026  return failure();
4027  return std::move(mappedValues);
4028 }
4029 
  // Registers the PDL rewrite functions convertValue, convertValues,
  // convertType and convertTypes. The type variants use the rewriter's current
  // type converter and return their input unchanged when there is none.
4031  patterns.getPDLPatterns().registerRewriteFunction(
4032  "convertValue",
4033  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4034  auto results = pdllConvertValues(
4035  static_cast<ConversionPatternRewriter &>(rewriter), value);
4036  if (failed(results))
4037  return failure();
  // NOTE(review): assumes the value remaps to at least one value; front()
  // on an empty result would be invalid.
4038  return results->front();
4039  });
4040  patterns.getPDLPatterns().registerRewriteFunction(
4041  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
4042  return pdllConvertValues(
4043  static_cast<ConversionPatternRewriter &>(rewriter), values);
4044  });
4045  patterns.getPDLPatterns().registerRewriteFunction(
4046  "convertType",
4047  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4048  auto &rewriterImpl =
4049  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4050  if (const TypeConverter *converter =
4051  rewriterImpl.currentTypeConverter) {
4052  if (Type newType = converter->convertType(type))
4053  return newType;
4054  return failure();
4055  }
4056  return type;
4057  });
4058  patterns.getPDLPatterns().registerRewriteFunction(
4059  "convertTypes",
4060  [](PatternRewriter &rewriter,
4061  TypeRange types) -> FailureOr<SmallVector<Type>> {
4062  auto &rewriterImpl =
4063  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4064  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
4065  if (!converter)
4066  return SmallVector<Type>(types);
4067 
4068  SmallVector<Type> remappedTypes;
4069  if (failed(converter->convertTypes(types, remappedTypes)))
4070  return failure();
4071  return std::move(remappedTypes);
4072  });
4073 }
4074 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4075 
4076 //===----------------------------------------------------------------------===//
4077 // Op Conversion Entry Points
4078 //===----------------------------------------------------------------------===//
4079 
4080 /// This is the type of Action that is dispatched when a conversion is applied.
/// Its tag is "apply-conversion", and print() writes only that tag.
/// NOTE(review): the class-head and constructor lines are not visible in this
/// listing.
4082  : public tracing::ActionImpl<ApplyConversionAction> {
4083 public:
4086  static constexpr StringLiteral tag = "apply-conversion";
4087  static constexpr StringLiteral desc =
4088  "Encapsulate the application of a dialect conversion";
4089 
4090  void print(raw_ostream &os) const override { os << tag; }
4091 };
4092 
/// Shared driver for the partial/full/analysis entry points: builds an
/// OperationConverter for `mode` and converts `ops`. An empty `ops` list
/// succeeds trivially. The conversion runs inside a lambda that is passed,
/// along with `irUnits`, to a call whose first line is not visible in this
/// listing -- presumably the dispatch of an ApplyConversionAction on `ctx`;
/// TODO confirm.
4093 static LogicalResult applyConversion(ArrayRef<Operation *> ops,
4094  const ConversionTarget &target,
4097  OpConversionMode mode) {
4098  if (ops.empty())
4099  return success();
4100  MLIRContext *ctx = ops.front()->getContext();
4101  LogicalResult status = success();
4102  SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4104  [&] {
4105  OperationConverter opConverter(ops.front()->getContext(), target,
4106  patterns, config, mode);
4107  status = opConverter.convertOperations(ops);
4108  },
4109  irUnits);
4110  return status;
4111 }
4112 
4113 //===----------------------------------------------------------------------===//
4114 // Partial Conversion
4115 //===----------------------------------------------------------------------===//
4116 
4118  ArrayRef<Operation *> ops, const ConversionTarget &target,
  // Partial mode: only ops explicitly marked illegal must be legalized.
4120  return applyConversion(ops, target, patterns, config,
4121  OpConversionMode::Partial);
4122 }
// Single-op convenience overload.
4123 LogicalResult
4127  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4128 }
4129 
4130 //===----------------------------------------------------------------------===//
4131 // Full Conversion
4132 //===----------------------------------------------------------------------===//
4133 
4135  const ConversionTarget &target,
  // Full mode: every visited op must be legalized.
4138  return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4139 }
4141  const ConversionTarget &target,
  // Single-op convenience overload.
4144  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4145 }
4146 
4147 //===----------------------------------------------------------------------===//
4148 // Analysis Conversion
4149 //===----------------------------------------------------------------------===//
4150 
4151 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
4152 /// op is a top-level module op (which is expected to be isolated from above),
4153 /// return that op.
/// "Top-level" here means the op has no parent op; it need not literally be a
/// module. Otherwise, starting from the nearest isolated-from-above parent of
/// the first op, walk outward until the candidate properly contains every op.
/// NOTE(review): `ops` is assumed non-empty (front() is used unconditionally).
4155  // Check if there is a top-level operation within `ops`. If so, return that
4156  // op.
4157  for (Operation *op : ops) {
4158  if (!op->getParentOp()) {
4159 #ifndef NDEBUG
4160  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4161  "expected top-level op to be isolated from above");
4162  for (Operation *other : ops)
4163  assert(op->isAncestor(other) &&
4164  "expected ops to have a common ancestor");
4165 #endif // NDEBUG
4166  return op;
4167  }
4168  }
4169 
4170  // No top-level op. Find a common ancestor.
4171  Operation *commonAncestor =
4172  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4173  for (Operation *op : ops.drop_front()) {
4174  while (!commonAncestor->isProperAncestor(op)) {
4175  commonAncestor =
4177  assert(commonAncestor &&
4178  "expected to find a common isolated from above ancestor");
4179  }
4180  }
4181 
4182  return commonAncestor;
4183 }
4184 
// Analysis conversion: runs the conversion on a clone of the closest common
// isolated-from-above ancestor, so the original IR is never modified. On
// return, `config.legalizableOps` (if set) refers to the original ops.
4188 #ifndef NDEBUG
4189  if (config.legalizableOps)
4190  assert(config.legalizableOps->empty() && "expected empty set");
4191 #endif // NDEBUG
4192 
4193  // Clone the closest common ancestor that is isolated from above.
4194  Operation *commonAncestor = findCommonAncestor(ops);
4195  IRMapping mapping;
4196  Operation *clonedAncestor = commonAncestor->clone(mapping);
4197  // Compute inverse IR mapping.
4198  DenseMap<Operation *, Operation *> inverseOperationMap;
4199  for (auto &it : mapping.getOperationMap())
4200  inverseOperationMap[it.second] = it.first;
4201 
4202  // Convert the cloned operations. The original IR will remain unchanged.
4203  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4204  ops, [&](Operation *op) { return mapping.lookup(op); });
4205  LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4206  OpConversionMode::Analysis);
4207 
4208  // Remap `legalizableOps`, so that they point to the original ops and not the
4209  // cloned ops.
  // NOTE(review): operator[] default-inserts (and yields nullptr) for any op
  // that is not a clone of an original op.
4210  if (config.legalizableOps) {
4211  DenseSet<Operation *> originalLegalizableOps;
4212  for (Operation *op : *config.legalizableOps)
4213  originalLegalizableOps.insert(inverseOperationMap[op]);
4214  *config.legalizableOps = std::move(originalLegalizableOps);
4215  }
4216 
4217  // Erase the cloned IR.
4218  clonedAncestor->erase();
4219  return status;
4220 }
4221 
// Single-op convenience overload.
4222 LogicalResult
4226  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4227 }
static void setInsertionPointAfter(OpBuilder &b, Value value)
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value >> &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
#define DEBUG_TYPE
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps, const SetVector< Block * > &insertedBlocks)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1243::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:94
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:308
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
UnitAttr getUnitAttr()
Definition: Builders.cpp:97
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:261
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:24
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
const ConversionConfig & getConfig() const
Return the configuration of the current dialect conversion.
void replaceAllUsesWith(Value from, ValueRange to)
Replace all the uses of from with to.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={}) override
PatternRewriter hook for inlining the ops of a block into another block.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Base class for the conversion patterns.
FailureOr< SmallVector< Value > > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:774
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:32
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:274
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:327
Block::iterator getPoint() const
Definition: Builders.h:340
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:337
Block * getBlock() const
Definition: Builders.h:339
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:320
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:472
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:610
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:456
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:420
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getTypes() const
Definition: ValueRange.cpp:28
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:304
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:718
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:606
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:236
result_range getResults()
Definition: Operation.h:415
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:896
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:365
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:140
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
Definition: PatternMatch.h:632
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:622
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacements values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:46
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult skip()
Definition: WalkResult.h:48
static WalkResult advance()
Definition: WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
CRTP Implementation of an action.
Definition: Action.h:76
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
@ AfterPatterns
Only attempt to fold not legal operations after applying patterns.
@ BeforePatterns
Only attempt to fold not legal operations before applying patterns.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
void reconcileUnrealizedCasts(const DenseSet< UnrealizedConversionCastOp > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
bool attachDebugMaterializationKind
If set to "true", the materialization kind ("source" or "target") will be attached to "builtin....
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:308
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:298
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:379
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:403
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:387
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:376
This struct represents a range of new types or a range of values that remaps an existing signature in...
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config)
void replaceAllUsesWith(Value from, ValueRange to, const TypeConverter *converter)
Replace the uses of the given value with the given values.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Look up the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Block * > patternInsertedBlocks
A set of blocks that were inserted (newly-created blocks or moved blocks) by the current pattern.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of IR rewrites (e.g., block creations, splits, and motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void replaceOp(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the results of the given operation with the given values and erase the operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.