MLIR  22.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Iterators.h"
17 #include "mlir/IR/Operation.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/Support/SaveAndRestore.h"
26 #include "llvm/Support/ScopedPrinter.h"
27 #include <optional>
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 #define DEBUG_TYPE "dialect-conversion"
33 
34 /// A utility function to log a successful result for the given reason.
35 template <typename... Args>
36 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
37  LLVM_DEBUG({
38  os.unindent();
39  os.startLine() << "} -> SUCCESS";
40  if (!fmt.empty())
41  os.getOStream() << " : "
42  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
43  os.getOStream() << "\n";
44  });
45 }
46 
47 /// A utility function to log a failure result for the given reason.
48 template <typename... Args>
49 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
50  LLVM_DEBUG({
51  os.unindent();
52  os.startLine() << "} -> FAILURE : "
53  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
54  << "\n";
55  });
56 }
57 
58 /// Helper function that computes an insertion point where the given value is
59 /// defined and can be used without a dominance violation.
/// For a block argument this is the start of its owning block; for an op
/// result it is the position immediately after the defining operation.
/// NOTE(review): the signature line of this function is missing from this
/// listing.
61  Block *insertBlock = value.getParentBlock();
// Default: the very beginning of the block (correct for block arguments).
62  Block::iterator insertPt = insertBlock->begin();
// Op results are only available after their defining op, so move past it.
63  if (OpResult inputRes = dyn_cast<OpResult>(value))
64  insertPt = ++inputRes.getOwner()->getIterator();
65  return OpBuilder::InsertPoint(insertBlock, insertPt);
66 }
67 
68 /// Helper function that computes an insertion point where the given values are
69 /// defined and can be used without a dominance violation.
/// Starting from the insertion point of the first value, each further value's
/// insertion point replaces the current one if the current one dominates it,
/// so the result is the "latest" of all per-value insertion points.
/// NOTE(review): the signature line and the definition of `nextPt` (presumably
/// the per-value insertion point of `v`) are missing from this listing.
71  assert(!vals.empty() && "expected at least one value");
72  DominanceInfo domInfo;
73  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
74  for (Value v : vals.drop_front()) {
75  // Choose the "later" insertion point.
77  if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
78  nextPt.getPoint())) {
79  // pt is before nextPt => choose nextPt.
80  pt = nextPt;
81  } else {
82 #ifndef NDEBUG
83  // nextPt should be before pt => choose pt.
84  // If pt and nextPt have no dominance relationship, then there is no valid
85  // insertion point at which all given values are defined.
86  bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
87  pt.getBlock(), pt.getPoint());
88  assert(dom && "unable to find valid insertion point");
89 #endif // NDEBUG
90  }
91  }
92  return pt;
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // ConversionValueMapping
97 //===----------------------------------------------------------------------===//
98 
99 /// A vector of SSA values, optimized for the most common case of a single
100 /// value.
101 using ValueVector = SmallVector<Value, 1>;
102 
103 namespace {
104 
105 /// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
106 struct ValueVectorMapInfo {
107  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
108  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
109  static ::llvm::hash_code getHashValue(const ValueVector &val) {
110  return ::llvm::hash_combine_range(val);
111  }
112  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
113  return LHS == RHS;
114  }
115 };
116 
117 /// Mapping from value vectors to value vectors used by the dialect conversion.
118 /// `lookup` returns only the direct mapping; it does not follow chains of mappings.
/// NOTE(review): the declaration of the `mapping` member (line 172 of the
/// original file) is missing from this listing.
119 struct ConversionValueMapping {
120  /// Return "true" if an SSA value is mapped to the given value. May return
121  /// false positives.
122  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
123 
124  /// Look up the direct mapping of a value vector. Returns an empty vector if
  /// there is none.
125  ValueVector lookup(const ValueVector &from) const;
126 
  /// Type trait: true iff `T` (after decay) is `ValueVector`. Used to select
  /// between the `map` overloads below.
127  template <typename T>
128  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
129 
130  /// Map a value vector to the one provided.
131  template <typename OldVal, typename NewVal>
132  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
133  map(OldVal &&oldVal, NewVal &&newVal) {
// Debug builds only: follow the chain of mappings starting at `newVal` and
// assert that it never reaches `oldVal`, i.e. the new entry creates no cycle.
134  LLVM_DEBUG({
135  ValueVector next(newVal);
136  while (true) {
137  assert(next != oldVal && "inserting cyclic mapping");
138  auto it = mapping.find(next);
139  if (it == mapping.end())
140  break;
141  next = it->second;
142  }
143  });
// Remember every replacement value so that isMappedTo() can answer cheaply.
144  mappedTo.insert_range(newVal);
145 
146  mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
147  }
148 
149  /// Map a value vector or single value to the one provided.
  /// Any single `Value` argument is wrapped into a `ValueVector` and the call
  /// is forwarded to the ValueVector/ValueVector overload above.
150  template <typename OldVal, typename NewVal>
151  std::enable_if_t<!IsValueVector<OldVal>::value ||
152  !IsValueVector<NewVal>::value>
153  map(OldVal &&oldVal, NewVal &&newVal) {
154  if constexpr (IsValueVector<OldVal>{}) {
155  map(std::forward<OldVal>(oldVal), ValueVector{newVal});
156  } else if constexpr (IsValueVector<NewVal>{}) {
157  map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
158  } else {
159  map(ValueVector{oldVal}, ValueVector{newVal});
160  }
161  }
162 
  /// Map a single value to a vector of replacement values.
163  void map(Value oldVal, SmallVector<Value> &&newVal) {
164  map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
165  }
166 
167  /// Remove the mapping (if any) whose key is the given value vector.
168  void erase(const ValueVector &value) { mapping.erase(value); }
169 
170 private:
171  /// Current value mappings.
173 
174  /// All SSA values that are mapped to. May contain false positives.
175  DenseSet<Value> mappedTo;
176 };
177 } // namespace
178 
179 /// Marker attribute for pure type conversions. I.e., mappings whose only
180 /// purpose is to resolve a type mismatch. (In contrast, mappings that point to
181 /// the replacement values of a "replaceOp" call, etc., are not pure type
182 /// conversions.)
183 static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
184 
185 /// Return the operation that defines all values in the vector. Return nullptr
186 /// if the values are not defined by the same operation.
187 static Operation *getCommonDefiningOp(const ValueVector &values) {
188  assert(!values.empty() && "expected non-empty value vector");
189  Operation *op = values.front().getDefiningOp();
190  for (Value v : llvm::drop_begin(values)) {
191  if (v.getDefiningOp() != op)
192  return nullptr;
193  }
194  return op;
195 }
196 
197 /// A vector of values is a pure type conversion if all values are defined by
198 /// the same operation and the operation has the `kPureTypeConversionMarker`
199 /// attribute.
200 static bool isPureTypeConversion(const ValueVector &values) {
201  assert(!values.empty() && "expected non-empty value vector");
202  Operation *op = getCommonDefiningOp(values);
203  return op && op->hasAttr(kPureTypeConversionMarker);
204 }
205 
206 ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
207  auto it = mapping.find(from);
208  if (it == mapping.end()) {
209  // No mapping found: The lookup stops here.
210  return {};
211  }
212  return it->second;
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // Rewriter and Translation State
217 //===----------------------------------------------------------------------===//
218 namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. It is taken
/// before a set of rewrites so that those rewrites can later be undone.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
        numReplacedOps(numReplacedOps) {}

  /// How many rewrites had been recorded when the snapshot was taken.
  unsigned numRewrites;

  /// How many operations had been marked as ignored.
  unsigned numIgnoredOperations;

  /// How many replaced operations were scheduled for erasure.
  unsigned numReplacedOps;
};
236 
237 //===----------------------------------------------------------------------===//
238 // IR rewrites
239 //===----------------------------------------------------------------------===//
240 
241 static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
242 
243 /// Notify the listener that the given block and its contents are being erased.
244 static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
245  for (Operation &op : b)
246  notifyIRErased(listener, op);
247  listener->notifyBlockErased(&b);
248 }
249 
250 /// Notify the listener that the given operation and its contents are being
251 /// erased.
252 static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
253  for (Region &r : op.getRegions()) {
254  for (Block &b : r) {
255  notifyIRErased(listener, b);
256  }
257  }
258  listener->notifyOperationErased(&op);
259 }
260 
261 /// An IR rewrite that can be committed (upon success) or rolled back (upon
262 /// failure).
263 ///
264 /// The dialect conversion keeps track of IR modifications (requested by the
265 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
266 /// are directly applied to the IR as the rewriter API is used, some are applied
267 /// partially, and some are delayed until the `IRRewrite` objects are committed.
268 class IRRewrite {
269 public:
270  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
271  /// Enum values are ordered, so that they can be used in `classof`: first all
272  /// block rewrites, then all operation rewrites.
273  enum class Kind {
274  // Block rewrites
275  CreateBlock,
276  EraseBlock,
277  InlineBlock,
278  MoveBlock,
279  BlockTypeConversion,
280  ReplaceBlockArg,
281  // Operation rewrites
282  MoveOperation,
283  ModifyOperation,
284  ReplaceOperation,
285  CreateOperation,
286  UnresolvedMaterialization
287  };
288 
289  virtual ~IRRewrite() = default;
290 
291  /// Roll back the rewrite. Operations may be erased during rollback.
292  virtual void rollback() = 0;
293 
294  /// Commit the rewrite. At this point, it is certain that the dialect
295  /// conversion will succeed. All IR modifications, except for operation/block
296  /// erasure, must be performed through the given rewriter.
297  ///
298  /// Instead of erasing operations/blocks, they should merely be unlinked
299  /// commit phase and finally be erased during the cleanup phase. This is
300  /// because internal dialect conversion state (such as `mapping`) may still
301  /// be using them.
302  ///
303  /// Any IR modification that was already performed before the commit phase
304  /// (e.g., insertion of an op) must be communicated to the listener that may
305  /// be attached to the given rewriter.
306  virtual void commit(RewriterBase &rewriter) {}
307 
308  /// Cleanup operations/blocks. Cleanup is called after commit.
309  virtual void cleanup(RewriterBase &rewriter) {}
310 
311  Kind getKind() const { return kind; }
312 
313  static bool classof(const IRRewrite *rewrite) { return true; }
314 
315 protected:
316  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
317  : kind(kind), rewriterImpl(rewriterImpl) {}
318 
319  const ConversionConfig &getConfig() const;
320 
321  const Kind kind;
322  ConversionPatternRewriterImpl &rewriterImpl;
323 };
324 
325 /// A block rewrite.
326 class BlockRewrite : public IRRewrite {
327 public:
328  /// Return the block that this rewrite operates on.
329  Block *getBlock() const { return block; }
330 
331  static bool classof(const IRRewrite *rewrite) {
332  return rewrite->getKind() >= Kind::CreateBlock &&
333  rewrite->getKind() <= Kind::ReplaceBlockArg;
334  }
335 
336 protected:
337  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
338  Block *block)
339  : IRRewrite(kind, rewriterImpl), block(block) {}
340 
341  // The block that this rewrite operates on.
342  Block *block;
343 };
344 
345 /// Creation of a block. Block creations are immediately reflected in the IR.
346 /// There is no extra work to commit the rewrite. During rollback, the newly
347 /// created block is erased.
348 class CreateBlockRewrite : public BlockRewrite {
349 public:
350  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
351  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
352 
353  static bool classof(const IRRewrite *rewrite) {
354  return rewrite->getKind() == Kind::CreateBlock;
355  }
356 
357  void commit(RewriterBase &rewriter) override {
358  // The block was already created and inserted. Just inform the listener.
359  if (auto *listener = rewriter.getListener())
360  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
361  }
362 
363  void rollback() override {
364  // Unlink all of the operations within this block, they will be deleted
365  // separately.
366  auto &blockOps = block->getOperations();
367  while (!blockOps.empty())
368  blockOps.remove(blockOps.begin());
369  block->dropAllUses();
370  if (block->getParent())
371  block->erase();
372  else
373  delete block;
374  }
375 };
376 
377 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
378 /// blocks are immediately unlinked, but only erased during cleanup. This makes
379 /// it easier to rollback a block erasure: the block is simply inserted into its
380 /// original location.
381 class EraseBlockRewrite : public BlockRewrite {
382 public:
383  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
384  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
385  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
386 
387  static bool classof(const IRRewrite *rewrite) {
388  return rewrite->getKind() == Kind::EraseBlock;
389  }
390 
391  ~EraseBlockRewrite() override {
392  assert(!block &&
393  "rewrite was neither rolled back nor committed/cleaned up");
394  }
395 
396  void rollback() override {
397  // The block (owned by this rewrite) was not actually erased yet. It was
398  // just unlinked. Put it back into its original position.
399  assert(block && "expected block");
400  auto &blockList = region->getBlocks();
401  Region::iterator before = insertBeforeBlock
402  ? Region::iterator(insertBeforeBlock)
403  : blockList.end();
404  blockList.insert(before, block);
405  block = nullptr;
406  }
407 
408  void commit(RewriterBase &rewriter) override {
409  assert(block && "expected block");
410 
411  // Notify the listener that the block and its contents are being erased.
412  if (auto *listener =
413  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
414  notifyIRErased(listener, *block);
415  }
416 
417  void cleanup(RewriterBase &rewriter) override {
418  // Erase the contents of the block.
419  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
420  rewriter.eraseOp(&op);
421  assert(block->empty() && "expected empty block");
422 
423  // Erase the block.
424  block->dropAllDefinedValueUses();
425  delete block;
426  block = nullptr;
427  }
428 
429 private:
430  // The region in which this block was previously contained.
431  Region *region;
432 
433  // The original successor of this block before it was unlinked. "nullptr" if
434  // this block was the only block in the region.
435  Block *insertBeforeBlock;
436 };
437 
438 /// Inlining of a block. This rewrite is immediately reflected in the IR.
439 /// Note: This rewrite represents only the inlining of the operations. The
440 /// erasure of the inlined block is a separate rewrite.
441 class InlineBlockRewrite : public BlockRewrite {
442 public:
443  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
444  Block *sourceBlock, Block::iterator before)
445  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
446  sourceBlock(sourceBlock),
447  firstInlinedInst(sourceBlock->empty() ? nullptr
448  : &sourceBlock->front()),
449  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
450  // If a listener is attached to the dialect conversion, ops must be moved
451  // one-by-one. When they are moved in bulk, notifications cannot be sent
452  // because the ops that used to be in the source block at the time of the
453  // inlining (before the "commit" phase) are unknown at the time when
454  // notifications are sent (which is during the "commit" phase).
455  assert(!getConfig().listener &&
456  "InlineBlockRewrite not supported if listener is attached");
457  }
458 
459  static bool classof(const IRRewrite *rewrite) {
460  return rewrite->getKind() == Kind::InlineBlock;
461  }
462 
463  void rollback() override {
464  // Put the operations from the destination block (owned by the rewrite)
465  // back into the source block.
466  if (firstInlinedInst) {
467  assert(lastInlinedInst && "expected operation");
468  sourceBlock->getOperations().splice(sourceBlock->begin(),
469  block->getOperations(),
470  Block::iterator(firstInlinedInst),
471  ++Block::iterator(lastInlinedInst));
472  }
473  }
474 
475 private:
476  // The block that originally contained the operations.
477  Block *sourceBlock;
478 
479  // The first inlined operation.
480  Operation *firstInlinedInst;
481 
482  // The last inlined operation.
483  Operation *lastInlinedInst;
484 };
485 
486 /// Moving of a block. This rewrite is immediately reflected in the IR.
487 class MoveBlockRewrite : public BlockRewrite {
488 public:
489  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
490  Region *previousRegion, Region::iterator previousIt)
491  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
492  region(previousRegion),
493  insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
494  : &*previousIt) {}
495 
496  static bool classof(const IRRewrite *rewrite) {
497  return rewrite->getKind() == Kind::MoveBlock;
498  }
499 
500  void commit(RewriterBase &rewriter) override {
501  // The block was already moved. Just inform the listener.
502  if (auto *listener = rewriter.getListener()) {
503  // Note: `previousIt` cannot be passed because this is a delayed
504  // notification and iterators into past IR state cannot be represented.
505  listener->notifyBlockInserted(block, /*previous=*/region,
506  /*previousIt=*/{});
507  }
508  }
509 
510  void rollback() override {
511  // Move the block back to its original position.
512  Region::iterator before =
513  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
514  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
515  }
516 
517 private:
518  // The region in which this block was previously contained.
519  Region *region;
520 
521  // The original successor of this block before it was moved. "nullptr" if
522  // this block was the only block in the region.
523  Block *insertBeforeBlock;
524 };
525 
526 /// Block type conversion. This rewrite is partially reflected in the IR.
527 class BlockTypeConversionRewrite : public BlockRewrite {
528 public:
529  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
530  Block *origBlock, Block *newBlock)
531  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
532  newBlock(newBlock) {}
533 
534  static bool classof(const IRRewrite *rewrite) {
535  return rewrite->getKind() == Kind::BlockTypeConversion;
536  }
537 
538  Block *getOrigBlock() const { return block; }
539 
540  Block *getNewBlock() const { return newBlock; }
541 
542  void commit(RewriterBase &rewriter) override;
543 
544  void rollback() override;
545 
546 private:
547  /// The new block that was created as part of this signature conversion.
548  Block *newBlock;
549 };
550 
551 /// Replacing a block argument. This rewrite is not immediately reflected in the
552 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
553 /// until the rewrite is committed.
554 class ReplaceBlockArgRewrite : public BlockRewrite {
555 public:
556  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
557  Block *block, BlockArgument arg,
558  const TypeConverter *converter)
559  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
560  converter(converter) {}
561 
562  static bool classof(const IRRewrite *rewrite) {
563  return rewrite->getKind() == Kind::ReplaceBlockArg;
564  }
565 
566  void commit(RewriterBase &rewriter) override;
567 
568  void rollback() override;
569 
570 private:
571  BlockArgument arg;
572 
573  /// The current type converter when the block argument was replaced.
574  const TypeConverter *converter;
575 };
576 
577 /// An operation rewrite.
578 class OperationRewrite : public IRRewrite {
579 public:
580  /// Return the operation that this rewrite operates on.
581  Operation *getOperation() const { return op; }
582 
583  static bool classof(const IRRewrite *rewrite) {
584  return rewrite->getKind() >= Kind::MoveOperation &&
585  rewrite->getKind() <= Kind::UnresolvedMaterialization;
586  }
587 
588 protected:
589  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
590  Operation *op)
591  : IRRewrite(kind, rewriterImpl), op(op) {}
592 
593  // The operation that this rewrite operates on.
594  Operation *op;
595 };
596 
597 /// Moving of an operation. This rewrite is immediately reflected in the IR.
598 class MoveOperationRewrite : public OperationRewrite {
599 public:
600  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
601  Operation *op, OpBuilder::InsertPoint previous)
602  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
603  block(previous.getBlock()),
604  insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
605  ? nullptr
606  : &*previous.getPoint()) {}
607 
608  static bool classof(const IRRewrite *rewrite) {
609  return rewrite->getKind() == Kind::MoveOperation;
610  }
611 
612  void commit(RewriterBase &rewriter) override {
613  // The operation was already moved. Just inform the listener.
614  if (auto *listener = rewriter.getListener()) {
615  // Note: `previousIt` cannot be passed because this is a delayed
616  // notification and iterators into past IR state cannot be represented.
617  listener->notifyOperationInserted(
618  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
619  /*insertPt=*/{}));
620  }
621  }
622 
623  void rollback() override {
624  // Move the operation back to its original position.
625  Block::iterator before =
626  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
627  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
628  }
629 
630 private:
631  // The block in which this operation was previously contained.
632  Block *block;
633 
634  // The original successor of this operation before it was moved. "nullptr"
635  // if this operation was the only operation in the region.
636  Operation *insertBeforeOp;
637 };
638 
639 /// In-place modification of an op. This rewrite is immediately reflected in
640 /// the IR. The previous state of the operation is stored in this object.
641 class ModifyOperationRewrite : public OperationRewrite {
642 public:
643  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
644  Operation *op)
645  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
646  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
647  operands(op->operand_begin(), op->operand_end()),
648  successors(op->successor_begin(), op->successor_end()) {
649  if (OpaqueProperties prop = op->getPropertiesStorage()) {
650  // Make a copy of the properties.
651  propertiesStorage = operator new(op->getPropertiesStorageSize());
652  OpaqueProperties propCopy(propertiesStorage);
653  name.initOpProperties(propCopy, /*init=*/prop);
654  }
655  }
656 
657  static bool classof(const IRRewrite *rewrite) {
658  return rewrite->getKind() == Kind::ModifyOperation;
659  }
660 
661  ~ModifyOperationRewrite() override {
662  assert(!propertiesStorage &&
663  "rewrite was neither committed nor rolled back");
664  }
665 
666  void commit(RewriterBase &rewriter) override {
667  // Notify the listener that the operation was modified in-place.
668  if (auto *listener =
669  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
670  listener->notifyOperationModified(op);
671 
672  if (propertiesStorage) {
673  OpaqueProperties propCopy(propertiesStorage);
674  // Note: The operation may have been erased in the mean time, so
675  // OperationName must be stored in this object.
676  name.destroyOpProperties(propCopy);
677  operator delete(propertiesStorage);
678  propertiesStorage = nullptr;
679  }
680  }
681 
682  void rollback() override {
683  op->setLoc(loc);
684  op->setAttrs(attrs);
685  op->setOperands(operands);
686  for (const auto &it : llvm::enumerate(successors))
687  op->setSuccessor(it.value(), it.index());
688  if (propertiesStorage) {
689  OpaqueProperties propCopy(propertiesStorage);
690  op->copyProperties(propCopy);
691  name.destroyOpProperties(propCopy);
692  operator delete(propertiesStorage);
693  propertiesStorage = nullptr;
694  }
695  }
696 
697 private:
698  OperationName name;
699  LocationAttr loc;
700  DictionaryAttr attrs;
701  SmallVector<Value, 8> operands;
702  SmallVector<Block *, 2> successors;
703  void *propertiesStorage = nullptr;
704 };
705 
706 /// Replacing an operation. Erasing an operation is treated as a special case
707 /// with "null" replacements. This rewrite is not immediately reflected in the
708 /// IR. An internal IR mapping is updated, but values are not replaced and the
709 /// original op is not erased until the rewrite is committed.
710 class ReplaceOperationRewrite : public OperationRewrite {
711 public:
712  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
713  Operation *op, const TypeConverter *converter)
714  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
715  converter(converter) {}
716 
717  static bool classof(const IRRewrite *rewrite) {
718  return rewrite->getKind() == Kind::ReplaceOperation;
719  }
720 
721  void commit(RewriterBase &rewriter) override;
722 
723  void rollback() override;
724 
725  void cleanup(RewriterBase &rewriter) override;
726 
727 private:
728  /// An optional type converter that can be used to materialize conversions
729  /// between the new and old values if necessary.
730  const TypeConverter *converter;
731 };
732 
733 class CreateOperationRewrite : public OperationRewrite {
734 public:
735  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
736  Operation *op)
737  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
738 
739  static bool classof(const IRRewrite *rewrite) {
740  return rewrite->getKind() == Kind::CreateOperation;
741  }
742 
743  void commit(RewriterBase &rewriter) override {
744  // The operation was already created and inserted. Just inform the listener.
745  if (auto *listener = rewriter.getListener())
746  listener->notifyOperationInserted(op, /*previous=*/{});
747  }
748 
749  void rollback() override;
750 };
751 
752 /// The type of materialization.
753 enum MaterializationKind {
754  /// This materialization materializes a conversion from an illegal type to a
755  /// legal one.
756  Target,
757 
758  /// This materialization materializes a conversion from a legal type back to
759  /// an illegal one.
760  Source
761 };
762 
763 /// Helper class that stores metadata about an unresolved materialization.
764 class UnresolvedMaterializationInfo {
765 public:
766  UnresolvedMaterializationInfo() = default;
767  UnresolvedMaterializationInfo(const TypeConverter *converter,
768  MaterializationKind kind, Type originalType)
769  : converterAndKind(converter, kind), originalType(originalType) {}
770 
771  /// Return the type converter of this materialization (which may be null).
772  const TypeConverter *getConverter() const {
773  return converterAndKind.getPointer();
774  }
775 
776  /// Return the kind of this materialization.
777  MaterializationKind getMaterializationKind() const {
778  return converterAndKind.getInt();
779  }
780 
781  /// Return the original type of the SSA value.
782  Type getOriginalType() const { return originalType; }
783 
784 private:
785  /// The corresponding type converter to use when resolving this
786  /// materialization, and the kind of this materialization.
787  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
788  converterAndKind;
789 
790  /// The original type of the SSA value. Only used for target
791  /// materializations.
792  Type originalType;
793 };
794 
795 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
796 /// op. Unresolved materializations fold away or are replaced with
797 /// source/target materializations at the end of the dialect conversion.
798 class UnresolvedMaterializationRewrite : public OperationRewrite {
799 public:
800  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
801  UnrealizedConversionCastOp op,
802  ValueVector mappedValues)
803  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
804  mappedValues(std::move(mappedValues)) {}
805 
806  static bool classof(const IRRewrite *rewrite) {
807  return rewrite->getKind() == Kind::UnresolvedMaterialization;
808  }
809 
810  void rollback() override;
811 
812  UnrealizedConversionCastOp getOperation() const {
813  return cast<UnrealizedConversionCastOp>(op);
814  }
815 
816 private:
817  /// The values in the conversion value mapping that are being replaced by the
818  /// results of this unresolved materialization.
819  ValueVector mappedValues;
820 };
821 } // namespace
822 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// operation is `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  for (auto &rewrite : rewrites) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    if (typed && typed->getOperation() == op)
      return true;
  }
  return false;
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// block is `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  for (auto &rewrite : rewrites) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    if (typed && typed->getBlock() == block)
      return true;
  }
  return false;
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
844 
845 //===----------------------------------------------------------------------===//
846 // ConversionPatternRewriterImpl
847 //===----------------------------------------------------------------------===//
848 namespace mlir {
849 namespace detail {
852  const ConversionConfig &config)
853  : rewriter(rewriter), config(config),
854  notifyingRewriter(rewriter.getContext(), config.listener) {}
855 
856  //===--------------------------------------------------------------------===//
857  // State Management
858  //===--------------------------------------------------------------------===//
859 
860  /// Return the current state of the rewriter.
861  RewriterState getCurrentState();
862 
863  /// Apply all requested operation rewrites. This method is invoked when the
864  /// conversion process succeeds.
865  void applyRewrites();
866 
867  /// Reset the state of the rewriter to a previously saved point. Optionally,
868  /// the name of the pattern that triggered the rollback can be specified for
869  /// debugging purposes.
870  void resetState(RewriterState state, StringRef patternName = "");
871 
872  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
873  /// failure.
874  template <typename RewriteTy, typename... Args>
875  void appendRewrite(Args &&...args) {
876  assert(config.allowPatternRollback && "appending rewrites is not allowed");
877  rewrites.push_back(
878  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
879  }
880 
881  /// Undo the rewrites (motions, splits) one by one in reverse order until
882  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
883  /// that triggered the rollback can be specified for debugging purposes.
884  void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
885 
886  /// Remap the given values to those with potentially different types. Returns
887  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
888  /// is the tag used when describing a value within a diagnostic, e.g.
889  /// "operand".
890  LogicalResult remapValues(StringRef valueDiagTag,
891  std::optional<Location> inputLoc, ValueRange values,
892  SmallVector<ValueVector> &remapped);
893 
894  /// Return "true" if the given operation is ignored, and does not need to be
895  /// converted.
896  bool isOpIgnored(Operation *op) const;
897 
898  /// Return "true" if the given operation was replaced or erased.
899  bool wasOpReplaced(Operation *op) const;
900 
901  /// Lookup the most recently mapped values with the desired types in the
902  /// mapping, taking into account only replacements. Perform a best-effort
903  /// search for existing materializations with the desired types.
904  ///
905  /// If `skipPureTypeConversions` is "true", materializations that are pure
906  /// type conversions are not considered.
907  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
908  bool skipPureTypeConversions = false) const;
909 
910  /// Lookup the given value within the map, or return an empty vector if the
911  /// value is not mapped. If it is mapped, this follows the same behavior
912  /// as `lookupOrDefault`.
913  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
914 
915  //===--------------------------------------------------------------------===//
916  // IR Rewrites / Type Conversion
917  //===--------------------------------------------------------------------===//
918 
919  /// Convert the types of block arguments within the given region.
920  FailureOr<Block *>
921  convertRegionTypes(Region *region, const TypeConverter &converter,
922  TypeConverter::SignatureConversion *entryConversion);
923 
924  /// Apply the given signature conversion on the given block. The new block
925  /// containing the updated signature is returned. If no conversions were
926  /// necessary, e.g. if the block has no arguments, `block` is returned.
927  /// `converter` is used to generate any necessary cast operations that
928  /// translate between the origin argument types and those specified in the
929  /// signature conversion.
930  Block *applySignatureConversion(
931  Block *block, const TypeConverter *converter,
932  TypeConverter::SignatureConversion &signatureConversion);
933 
934  /// Replace the results of the given operation with the given values and
935  /// erase the operation.
936  ///
937  /// There can be multiple replacement values for each result (1:N
938  /// replacement). If the replacement values are empty, the respective result
939  /// is dropped and a source materialization is built if the result still has
940  /// uses.
941  void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
942 
943  /// Replace the given block argument with the given values. The specified
944  /// converter is used to build materializations (if necessary).
945  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
946  const TypeConverter *converter);
947 
948  /// Erase the given block and its contents.
949  void eraseBlock(Block *block);
950 
951  /// Inline the source block into the destination block before the given
952  /// iterator.
953  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
954 
955  //===--------------------------------------------------------------------===//
956  // Materializations
957  //===--------------------------------------------------------------------===//
958 
959  /// Build an unresolved materialization operation given a range of output
960  /// types and a list of input operands. The types of `inputs` must not match
961  /// `outputTypes`; otherwise, no materialization is needed.
962  ///
963  /// Returns the results of the newly built unrealized_conversion_cast op.
964  /// `originalType` may be non-null only for target materializations.
965  ///
966  /// If `valuesToMap` is non-empty, then those values are mapped to the
967  /// results of the unresolved materialization in the conversion value
968  /// mapping.
969  ///
970  /// If `isPureTypeConversion` is "true", the materialization is created only
971  /// to resolve a type mismatch. That means it is not a regular value
972  /// replacement issued by the user. (Replacement values that are created
973  /// "out of thin air" appear like unresolved materializations because they are
974  /// unrealized_conversion_cast ops. However, they must be treated like
975  /// regular value replacements.)
976  ValueRange buildUnresolvedMaterialization(
977  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
978  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
979  Type originalType, const TypeConverter *converter,
980  bool isPureTypeConversion = true);
981 
982  /// Find a replacement value for the given SSA value in the conversion value
983  /// mapping. The replacement value must have the same type as the given SSA
984  /// value. If there is no replacement value with the correct type, find the
985  /// latest replacement value (regardless of the type) and build a source
986  /// materialization.
987  Value findOrBuildReplacementValue(Value value,
988  const TypeConverter *converter);
989 
990  //===--------------------------------------------------------------------===//
991  // Rewriter Notification Hooks
992  //===--------------------------------------------------------------------===//
993 
994  /// Notifies that an op was inserted.
995  void notifyOperationInserted(Operation *op,
996  OpBuilder::InsertPoint previous) override;
997 
998  /// Notifies that a block was inserted.
999  void notifyBlockInserted(Block *block, Region *previous,
1000  Region::iterator previousIt) override;
1001 
1002  /// Notifies that a pattern match failed for the given reason.
1003  void
1004  notifyMatchFailure(Location loc,
1005  function_ref<void(Diagnostic &)> reasonCallback) override;
1006 
1007  //===--------------------------------------------------------------------===//
1008  // IR Erasure
1009  //===--------------------------------------------------------------------===//
1010 
1011  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1012  /// operation or block is erased multiple times. This rewriter assumes that
1013  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
1015  public:
1017  MLIRContext *context,
1018  std::function<void(Operation *)> opErasedCallback = nullptr)
1019  : RewriterBase(context, /*listener=*/this),
1020  opErasedCallback(opErasedCallback) {}
1021 
1022  /// Erase the given op (unless it was already erased).
1023  void eraseOp(Operation *op) override {
1024  if (wasErased(op))
1025  return;
1026  op->dropAllUses();
1028  }
1029 
1030  /// Erase the given block (unless it was already erased).
1031  void eraseBlock(Block *block) override {
1032  if (wasErased(block))
1033  return;
1034  assert(block->empty() && "expected empty block");
1035  block->dropAllDefinedValueUses();
1036  RewriterBase::eraseBlock(block);
1037  }
1038 
1039  bool wasErased(void *ptr) const { return erased.contains(ptr); }
1040 
1041  void notifyOperationErased(Operation *op) override {
1042  erased.insert(op);
1043  if (opErasedCallback)
1044  opErasedCallback(op);
1045  }
1046 
1047  void notifyBlockErased(Block *block) override { erased.insert(block); }
1048 
1049  private:
1050  /// Pointers to all erased operations and blocks.
1051  DenseSet<void *> erased;
1052 
1053  /// A callback that is invoked when an operation is erased.
1054  std::function<void(Operation *)> opErasedCallback;
1055  };
1056 
1057  //===--------------------------------------------------------------------===//
1058  // State
1059  //===--------------------------------------------------------------------===//
1060 
1061  /// The rewriter that is used to perform the conversion.
1063 
1064  // Mapping between replaced values that differ in type. This happens when
1065  // replacing a value with one of a different type.
1066  ConversionValueMapping mapping;
1067 
1068  /// Ordered list of block operations (creations, splits, motions).
1069  /// This vector is maintained only if `allowPatternRollback` is set to
1070  /// "true". Otherwise, all IR rewrites are materialized immediately and no
1071  /// bookkeeping is needed.
1073 
1074  /// A set of operations that should no longer be considered for legalization.
1075  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1076  /// tracked separately.
1078 
1079  /// A set of operations that were replaced/erased. Such ops are not erased
1080  /// immediately but only when the dialect conversion succeeds. In the
1081  /// meantime, they should no longer be considered for legalization and any attempt
1082  /// to modify/access them is invalid rewriter API usage.
1084 
1085  /// A set of operations that were created by the current pattern.
1087 
1088  /// A set of operations that were modified by the current pattern.
1090 
1091  /// A set of blocks that were inserted (newly-created blocks or moved blocks)
1092  /// by the current pattern.
1094 
1095  /// A list of unresolved materializations that were created by the current
1096  /// pattern.
1098 
1099  /// A mapping for looking up metadata of unresolved materializations.
1102 
1103  /// The current type converter, or nullptr if no type converter is currently
1104  /// active.
1105  const TypeConverter *currentTypeConverter = nullptr;
1106 
1107  /// A mapping of regions to type converters that should be used when
1108  /// converting the arguments of blocks within that region.
1110 
1111  /// Dialect conversion configuration.
1113 
1114  /// A set of erased operations. This set is utilized only if
1115  /// `allowPatternRollback` is set to "false". Conceptually, this set is
1116  /// similar to `replacedOps` (which is maintained when the flag is set to
1117  /// "true"). However, erasing from a DenseSet is more efficient than erasing
1118  /// from a SetVector.
1120 
1121  /// A set of erased blocks. This set is utilized only if
1122  /// `allowPatternRollback` is set to "false".
1124 
1125  /// A rewriter that notifies the listener (if any) about all IR
1126  /// modifications. This rewriter is utilized only if `allowPatternRollback`
1127  /// is set to "false". If the flag is set to "true", the listener is notified
1128  /// with a separate mechanism (e.g., in `IRRewrite::commit`).
1130 
1131 #ifndef NDEBUG
1132  /// A set of replaced block arguments. This set is for debugging purposes
1133  /// only and it is maintained only if `allowPatternRollback` is set to
1134  /// "true".
1136 
1137  /// A set of operations that have pending updates. This tracking isn't
1138  /// strictly necessary, and is thus only active during debug builds for extra
1139  /// verification.
1141 
1142  /// A raw output stream used to prefix the debug log.
1143  llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
1144  llvm::dbgs()};
1145 
1146  /// A logger used to emit diagnostics during the conversion process.
1147  llvm::ScopedPrinter logger{os};
1148  std::string logPrefix;
1149 #endif
1150 };
1151 } // namespace detail
1152 } // namespace mlir
1153 
1154 const ConversionConfig &IRRewrite::getConfig() const {
1155  return rewriterImpl.config;
1156 }
1157 
1158 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1159  // Inform the listener about all IR modifications that have already taken
1160  // place: References to the original block have been replaced with the new
1161  // block.
1162  if (auto *listener =
1163  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1164  for (Operation *op : getNewBlock()->getUsers())
1165  listener->notifyOperationModified(op);
1166 }
1167 
1168 void BlockTypeConversionRewrite::rollback() {
1169  getNewBlock()->replaceAllUsesWith(getOrigBlock());
1170 }
1171 
1173  Value repl) {
1174  if (isa<BlockArgument>(repl)) {
1175  rewriter.replaceAllUsesWith(arg, repl);
1176  return;
1177  }
1178 
1179  // If the replacement value is an operation, we check to make sure that we
1180  // don't replace uses that are within the parent operation of the
1181  // replacement value.
1182  Operation *replOp = cast<OpResult>(repl).getOwner();
1183  Block *replBlock = replOp->getBlock();
1184  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1185  Operation *user = operand.getOwner();
1186  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1187  });
1188 }
1189 
1190 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1191  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1192  if (!repl)
1193  return;
1194  performReplaceBlockArg(rewriter, arg, repl);
1195 }
1196 
1197 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
1198 
1199 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1200  auto *listener =
1201  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1202 
1203  // Compute replacement values.
1204  SmallVector<Value> replacements =
1205  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1206  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1207  });
1208 
1209  // Notify the listener that the operation is about to be replaced.
1210  if (listener)
1211  listener->notifyOperationReplaced(op, replacements);
1212 
1213  // Replace all uses with the new values.
1214  for (auto [result, newValue] :
1215  llvm::zip_equal(op->getResults(), replacements))
1216  if (newValue)
1217  rewriter.replaceAllUsesWith(result, newValue);
1218 
1219  // The original op will be erased, so remove it from the set of unlegalized
1220  // ops.
1221  if (getConfig().unlegalizedOps)
1222  getConfig().unlegalizedOps->erase(op);
1223 
1224  // Notify the listener that the operation and its contents are being erased.
1225  if (listener)
1226  notifyIRErased(listener, *op);
1227 
1228  // Do not erase the operation yet. It may still be referenced in `mapping`.
1229  // Just unlink it for now and erase it during cleanup.
1230  op->getBlock()->getOperations().remove(op);
1231 }
1232 
1233 void ReplaceOperationRewrite::rollback() {
1234  for (auto result : op->getResults())
1235  rewriterImpl.mapping.erase({result});
1236 }
1237 
1238 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1239  rewriter.eraseOp(op);
1240 }
1241 
1242 void CreateOperationRewrite::rollback() {
1243  for (Region &region : op->getRegions()) {
1244  while (!region.getBlocks().empty())
1245  region.getBlocks().remove(region.getBlocks().begin());
1246  }
1247  op->dropAllUses();
1248  op->erase();
1249 }
1250 
1251 void UnresolvedMaterializationRewrite::rollback() {
1252  if (!mappedValues.empty())
1253  rewriterImpl.mapping.erase(mappedValues);
1254  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1255  op->erase();
1256 }
1257 
1259  // Commit all rewrites. Use a new rewriter, so the modifications are not
1260  // tracked for rollback purposes etc.
1261  IRRewriter irRewriter(rewriter.getContext(), config.listener);
1262  // Note: New rewrites may be added during the "commit" phase and the
1263  // `rewrites` vector may reallocate.
1264  for (size_t i = 0; i < rewrites.size(); ++i)
1265  rewrites[i]->commit(irRewriter);
1266 
1267  // Clean up all rewrites.
1268  SingleEraseRewriter eraseRewriter(
1269  rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
1270  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1271  unresolvedMaterializations.erase(castOp);
1272  });
1273  for (auto &rewrite : rewrites)
1274  rewrite->cleanup(eraseRewriter);
1275 }
1276 
1277 //===----------------------------------------------------------------------===//
1278 // State Management
1279 //===----------------------------------------------------------------------===//
1280 
1282  Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1283  // Helper function that looks up a single value.
1284  auto lookup = [&](const ValueVector &values) -> ValueVector {
1285  assert(!values.empty() && "expected non-empty value vector");
1286 
1287  // If the pattern rollback is enabled, use the mapping to look up the
1288  // values.
1290  return mapping.lookup(values);
1291 
1292  // Otherwise, look up values by examining the IR. All replacements have
1293  // already been materialized in IR.
1294  Operation *op = getCommonDefiningOp(values);
1295  if (!op)
1296  return {};
1297  auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1298  if (!castOp)
1299  return {};
1300  if (!this->unresolvedMaterializations.contains(castOp))
1301  return {};
1302  if (castOp.getOutputs() != values)
1303  return {};
1304  return castOp.getInputs();
1305  };
1306 
1307  // Helper function that looks up each value in `values` individually and then
1308  // composes the results. If that fails, it tries to look up the entire vector
1309  // at once.
1310  auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1311  // If possible, replace each value with (one or multiple) mapped values.
1312  ValueVector next;
1313  for (Value v : values) {
1314  ValueVector r = lookup({v});
1315  if (!r.empty()) {
1316  llvm::append_range(next, r);
1317  } else {
1318  next.push_back(v);
1319  }
1320  }
1321  if (next != values) {
1322  // At least one value was replaced.
1323  return next;
1324  }
1325 
1326  // Otherwise: Check if there is a mapping for the entire vector. Such
1327  // mappings are materializations. (N:M mapping are not supported for value
1328  // replacements.)
1329  //
1330  // Note: From a correctness point of view, materializations do not have to
1331  // be stored (and looked up) in the mapping. But for performance reasons,
1332  // we choose to reuse existing IR (when possible) instead of creating it
1333  // multiple times.
1334  ValueVector r = lookup(values);
1335  if (r.empty()) {
1336  // No mapping found: The lookup stops here.
1337  return {};
1338  }
1339  return r;
1340  };
1341 
1342  // Try to find the deepest values that have the desired types. If there is no
1343  // such mapping, simply return the deepest values.
1344  ValueVector desiredValue;
1345  ValueVector current{from};
1346  ValueVector lastNonMaterialization{from};
1347  do {
1348  // Store the current value if the types match.
1349  bool match = TypeRange(ValueRange(current)) == desiredTypes;
1350  if (skipPureTypeConversions) {
1351  // Skip pure type conversions, if requested.
1352  bool pureConversion = isPureTypeConversion(current);
1353  match &= !pureConversion;
1354  // Keep track of the last mapped value that was not a pure type
1355  // conversion.
1356  if (!pureConversion)
1357  lastNonMaterialization = current;
1358  }
1359  if (match)
1360  desiredValue = current;
1361 
1362  // Lookup next value in the mapping.
1363  ValueVector next = composedLookup(current);
1364  if (next.empty())
1365  break;
1366  current = std::move(next);
1367  } while (true);
1368 
1369  // If the desired values were found use them, otherwise default to the leaf
1370  // values. (Skip pure type conversions, if requested.)
1371  if (!desiredTypes.empty())
1372  return desiredValue;
1373  if (skipPureTypeConversions)
1374  return lastNonMaterialization;
1375  return current;
1376 }
1377 
1380  TypeRange desiredTypes) const {
1381  ValueVector result = lookupOrDefault(from, desiredTypes);
1382  if (result == ValueVector{from} ||
1383  (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
1384  return {};
1385  return result;
1386 }
1387 
1389  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1390 }
1391 
1393  StringRef patternName) {
1394  // Undo any rewrites.
1395  undoRewrites(state.numRewrites, patternName);
1396 
1397  // Pop all of the recorded ignored operations that are no longer valid.
1398  while (ignoredOps.size() != state.numIgnoredOperations)
1399  ignoredOps.pop_back();
1400 
1401  while (replacedOps.size() != state.numReplacedOps)
1402  replacedOps.pop_back();
1403 }
1404 
1405 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1406  StringRef patternName) {
1407  for (auto &rewrite :
1408  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1409  rewrite->rollback();
1410  rewrites.resize(numRewritesToKeep);
1411 }
1412 
1414  StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1415  SmallVector<ValueVector> &remapped) {
1416  remapped.reserve(llvm::size(values));
1417 
1418  for (const auto &it : llvm::enumerate(values)) {
1419  Value operand = it.value();
1420  Type origType = operand.getType();
1421  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1422 
1423  if (!currentTypeConverter) {
1424  // The current pattern does not have a type converter. Pass the most
1425  // recently mapped values, excluding materializations. Materializations
1426  // are intentionally excluded because their presence may depend on other
1427  // patterns. Including materializations would make the lookup fragile
1428  // and unpredictable.
1429  remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1430  /*skipPureTypeConversions=*/true));
1431  continue;
1432  }
1433 
1434  // If there is no legal conversion, fail to match this pattern.
1435  SmallVector<Type, 1> legalTypes;
1436  if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
1437  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1438  diag << "unable to convert type for " << valueDiagTag << " #"
1439  << it.index() << ", type was " << origType;
1440  });
1441  return failure();
1442  }
1443  // If a type is converted to 0 types, there is nothing to do.
1444  if (legalTypes.empty()) {
1445  remapped.push_back({});
1446  continue;
1447  }
1448 
1449  ValueVector repl = lookupOrDefault(operand, legalTypes);
1450  if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1451  // Mapped values have the correct type or there is an existing
1452  // materialization. Or the operand is not mapped at all and has the
1453  // correct type.
1454  remapped.push_back(std::move(repl));
1455  continue;
1456  }
1457 
1458  // Create a materialization for the most recently mapped values.
1459  repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1460  /*skipPureTypeConversions=*/true);
1462  MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1463  /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1464  /*originalType=*/origType, currentTypeConverter);
1465  remapped.push_back(castValues);
1466  }
1467  return success();
1468 }
1469 
1471  // Check to see if this operation is ignored or was replaced.
1472  return wasOpReplaced(op) || ignoredOps.count(op);
1473 }
1474 
1476  // Check to see if this operation was replaced.
1477  return replacedOps.count(op) || erasedOps.count(op);
1478 }
1479 
1480 //===----------------------------------------------------------------------===//
1481 // Type Conversion
1482 //===----------------------------------------------------------------------===//
1483 
1485  Region *region, const TypeConverter &converter,
1486  TypeConverter::SignatureConversion *entryConversion) {
1487  regionToConverter[region] = &converter;
1488  if (region->empty())
1489  return nullptr;
1490 
1491  // Convert the arguments of each non-entry block within the region.
1492  for (Block &block :
1493  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1494  // Compute the signature for the block with the provided converter.
1495  std::optional<TypeConverter::SignatureConversion> conversion =
1496  converter.convertBlockSignature(&block);
1497  if (!conversion)
1498  return failure();
1499  // Convert the block with the computed signature.
1500  applySignatureConversion(&block, &converter, *conversion);
1501  }
1502 
1503  // Convert the entry block. If an entry signature conversion was provided,
1504  // use that one. Otherwise, compute the signature with the type converter.
1505  if (entryConversion)
1506  return applySignatureConversion(&region->front(), &converter,
1507  *entryConversion);
1508  std::optional<TypeConverter::SignatureConversion> conversion =
1509  converter.convertBlockSignature(&region->front());
1510  if (!conversion)
1511  return failure();
1512  return applySignatureConversion(&region->front(), &converter, *conversion);
1513 }
1514 
1516  Block *block, const TypeConverter *converter,
1517  TypeConverter::SignatureConversion &signatureConversion) {
1518 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1519  // A block cannot be converted multiple times.
1520  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1521  llvm::report_fatal_error("block was already converted");
1522 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1523 
1525 
1526  // If no arguments are being changed or added, there is nothing to do.
1527  unsigned origArgCount = block->getNumArguments();
1528  auto convertedTypes = signatureConversion.getConvertedTypes();
1529  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1530  return block;
1531 
1532  // Compute the locations of all block arguments in the new block.
1533  SmallVector<Location> newLocs(convertedTypes.size(),
1535  for (unsigned i = 0; i < origArgCount; ++i) {
1536  auto inputMap = signatureConversion.getInputMapping(i);
1537  if (!inputMap || inputMap->replacedWithValues())
1538  continue;
1539  Location origLoc = block->getArgument(i).getLoc();
1540  for (unsigned j = 0; j < inputMap->size; ++j)
1541  newLocs[inputMap->inputNo + j] = origLoc;
1542  }
1543 
1544  // Insert a new block with the converted block argument types and move all ops
1545  // from the old block to the new block.
1546  Block *newBlock =
1547  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1548  convertedTypes, newLocs);
1549 
1550  // If a listener is attached to the dialect conversion, ops cannot be moved
1551  // to the destination block in bulk ("fast path"). This is because at the time
1552  // the notifications are sent, it is unknown which ops were moved. Instead,
1553  // ops should be moved one-by-one ("slow path"), so that a separate
1554  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1555  // a bit more efficient, so we try to do that when possible.
1556  bool fastPath = !config.listener;
1557  if (fastPath) {
1559  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1560  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1561  } else {
1562  while (!block->empty())
1563  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1564  }
1565 
1566  // Replace all uses of the old block with the new block.
1567  block->replaceAllUsesWith(newBlock);
1568 
1569  for (unsigned i = 0; i != origArgCount; ++i) {
1570  BlockArgument origArg = block->getArgument(i);
1571  Type origArgType = origArg.getType();
1572 
1573  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1574  signatureConversion.getInputMapping(i);
1575  if (!inputMap) {
1576  // This block argument was dropped and no replacement value was provided.
1577  // Materialize a replacement value "out of thin air".
1578  Value mat =
1580  MaterializationKind::Source,
1581  OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1582  origArg.getLoc(),
1583  /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1584  /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1585  /*isPureTypeConversion=*/false)
1586  .front();
1587  replaceUsesOfBlockArgument(origArg, mat, converter);
1588  continue;
1589  }
1590 
1591  if (inputMap->replacedWithValues()) {
1592  // This block argument was dropped and replacement values were provided.
1593  assert(inputMap->size == 0 &&
1594  "invalid to provide a replacement value when the argument isn't "
1595  "dropped");
1596  replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
1597  converter);
1598  continue;
1599  }
1600 
1601  // This is a 1->1+ mapping.
1602  auto replArgs =
1603  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1604  replaceUsesOfBlockArgument(origArg, replArgs, converter);
1605  }
1606 
1608  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1609 
1610  // Erase the old block. (It is just unlinked for now and will be erased during
1611  // cleanup.)
1612  rewriter.eraseBlock(block);
1613 
1614  return newBlock;
1615 }
1616 
1617 //===----------------------------------------------------------------------===//
1618 // Materializations
1619 //===----------------------------------------------------------------------===//
1620 
1621 /// Build an unresolved materialization operation given a range of output types
1622 /// and a list of input operands.
1624  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1625  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1626  Type originalType, const TypeConverter *converter,
1627  bool isPureTypeConversion) {
1628  assert((!originalType || kind == MaterializationKind::Target) &&
1629  "original type is valid only for target materializations");
1630  assert(TypeRange(inputs) != outputTypes &&
1631  "materialization is not necessary");
1632 
1633  // Create an unresolved materialization. We use a new OpBuilder to avoid
1634  // tracking the materialization like we do for other operations.
1635  OpBuilder builder(outputTypes.front().getContext());
1636  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1637  UnrealizedConversionCastOp convertOp =
1638  UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1640  StringRef kindStr =
1641  kind == MaterializationKind::Source ? "source" : "target";
1642  convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1643  }
1645  convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
1646 
1647  // Register the materialization.
1648  unresolvedMaterializations[convertOp] =
1649  UnresolvedMaterializationInfo(converter, kind, originalType);
1651  if (!valuesToMap.empty())
1652  mapping.map(valuesToMap, convertOp.getResults());
1653  appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1654  std::move(valuesToMap));
1655  } else {
1656  patternMaterializations.insert(convertOp);
1657  }
1658  return convertOp.getResults();
1659 }
1660 
1662  Value value, const TypeConverter *converter) {
1663  assert(config.allowPatternRollback &&
1664  "this code path is valid only in rollback mode");
1665 
1666  // Try to find a replacement value with the same type in the conversion value
1667  // mapping. This includes cached materializations. We try to reuse those
1668  // instead of generating duplicate IR.
1669  ValueVector repl = lookupOrNull(value, value.getType());
1670  if (!repl.empty())
1671  return repl.front();
1672 
1673  // Check if the value is dead. No replacement value is needed in that case.
1674  // This is an approximate check that may have false negatives but does not
1675  // require computing and traversing an inverse mapping. (We may end up
1676  // building source materializations that are never used and that fold away.)
1677  if (llvm::all_of(value.getUsers(),
1678  [&](Operation *op) { return replacedOps.contains(op); }) &&
1679  !mapping.isMappedTo(value))
1680  return Value();
1681 
1682  // No replacement value was found. Get the latest replacement value
1683  // (regardless of the type) and build a source materialization to the
1684  // original type.
1685  repl = lookupOrNull(value);
1686  if (repl.empty()) {
1687  // No replacement value is registered in the mapping. This means that the
1688  // value is dropped and no longer needed. (If the value were still needed,
1689  // a source materialization producing a replacement value "out of thin air"
1690  // would have already been created during `replaceOp` or
1691  // `applySignatureConversion`.)
1692  return Value();
1693  }
1694 
1695  // Note: `computeInsertPoint` computes the "earliest" insertion point at
1696  // which all values in `repl` are defined. It is important to emit the
1697  // materialization at that location because the same materialization may be
1698  // reused in a different context. (That's because materializations are cached
1699  // in the conversion value mapping.) The insertion point of the
1700  // materialization must be valid for all future users that may be created
1701  // later in the conversion process.
1702  Value castValue =
1703  buildUnresolvedMaterialization(MaterializationKind::Source,
1704  computeInsertPoint(repl), value.getLoc(),
1705  /*valuesToMap=*/repl, /*inputs=*/repl,
1706  /*outputTypes=*/value.getType(),
1707  /*originalType=*/Type(), converter)
1708  .front();
1709  return castValue;
1710 }
1711 
1712 //===----------------------------------------------------------------------===//
1713 // Rewriter Notification Hooks
1714 //===----------------------------------------------------------------------===//
1715 
1717  Operation *op, OpBuilder::InsertPoint previous) {
1718  // If no previous insertion point is provided, the op used to be detached.
1719  bool wasDetached = !previous.isSet();
1720  LLVM_DEBUG({
1721  logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1722  << ")";
1723  if (wasDetached)
1724  logger.getOStream() << " (was detached)";
1725  logger.getOStream() << "\n";
1726  });
1727 
1728  // In rollback mode, it is easier to misuse the API, so perform extra error
1729  // checking.
1730  assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1731  "attempting to insert into a block within a replaced/erased op");
1732 
1733  // In "no rollback" mode, the listener is always notified immediately.
1735  config.listener->notifyOperationInserted(op, previous);
1736 
1737  if (wasDetached) {
1738  // If the op was detached, it is most likely a newly created op. Add it the
1739  // set of newly created ops, so that it will be legalized. If this op is
1740  // not a newly created op, it will be legalized a second time, which is
1741  // inefficient but harmless.
1742  patternNewOps.insert(op);
1743 
1745  // TODO: If the same op is inserted multiple times from a detached
1746  // state, the rollback mechanism may erase the same op multiple times.
1747  // This is a bug in the rollback-based dialect conversion driver.
1748  appendRewrite<CreateOperationRewrite>(op);
1749  } else {
1750  // In "no rollback" mode, there is an extra data structure for tracking
1751  // erased operations that must be kept up to date.
1752  erasedOps.erase(op);
1753  }
1754  return;
1755  }
1756 
1757  // The op was moved from one place to another.
1759  appendRewrite<MoveOperationRewrite>(op, previous);
1760 }
1761 
1762 /// Given that `fromRange` is about to be replaced with `toRange`, compute
1763 /// replacement values with the types of `fromRange`.
1764 static SmallVector<Value>
1766  const SmallVector<SmallVector<Value>> &toRange,
1767  const TypeConverter *converter) {
1768  assert(!impl.config.allowPatternRollback &&
1769  "this code path is valid only in 'no rollback' mode");
1770  SmallVector<Value> repls;
1771  for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1772  if (from.use_empty()) {
1773  // The replaced value is dead. No replacement value is needed.
1774  repls.push_back(Value());
1775  continue;
1776  }
1777 
1778  if (to.empty()) {
1779  // The replaced value is dropped. Materialize a replacement value "out of
1780  // thin air".
1781  Value srcMat = impl.buildUnresolvedMaterialization(
1782  MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1783  /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1784  /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1785  converter)[0];
1786  repls.push_back(srcMat);
1787  continue;
1788  }
1789 
1790  if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1791  // The replacement value already has the correct type. Use it directly.
1792  repls.push_back(to[0]);
1793  continue;
1794  }
1795 
1796  // The replacement value has the wrong type. Build a source materialization
1797  // to the original type.
1798  // TODO: This is a bit inefficient. We should try to reuse existing
1799  // materializations if possible. This would require an extension of the
1800  // `lookupOrDefault` API.
1801  Value srcMat = impl.buildUnresolvedMaterialization(
1802  MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1803  /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1804  /*originalType=*/Type(), converter)[0];
1805  repls.push_back(srcMat);
1806  }
1807 
1808  return repls;
1809 }
1810 
1812  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1813  assert(newValues.size() == op->getNumResults() &&
1814  "incorrect number of replacement values");
1815 
1817  // Pattern rollback is not allowed: materialize all IR changes immediately.
1819  *this, op->getResults(), newValues, currentTypeConverter);
1820  // Update internal data structures, so that there are no dangling pointers
1821  // to erased IR.
1822  op->walk([&](Operation *op) {
1823  erasedOps.insert(op);
1824  ignoredOps.remove(op);
1825  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1826  unresolvedMaterializations.erase(castOp);
1827  patternMaterializations.erase(castOp);
1828  }
1829  // The original op will be erased, so remove it from the set of
1830  // unlegalized ops.
1831  if (config.unlegalizedOps)
1832  config.unlegalizedOps->erase(op);
1833  });
1834  op->walk([&](Block *block) { erasedBlocks.insert(block); });
1835  // Replace the op with the replacement values and notify the listener.
1836  notifyingRewriter.replaceOp(op, repls);
1837  return;
1838  }
1839 
1840  assert(!ignoredOps.contains(op) && "operation was already replaced");
1841 
1842  // Check if replaced op is an unresolved materialization, i.e., an
1843  // unrealized_conversion_cast op that was created by the conversion driver.
1844  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1845  // Make sure that the user does not mess with unresolved materializations
1846  // that were inserted by the conversion driver. We keep track of these
1847  // ops in internal data structures.
1848  assert(!unresolvedMaterializations.contains(castOp) &&
1849  "attempting to replace/erase an unresolved materialization");
1850  }
1851 
1852  // Create mappings for each of the new result values.
1853  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
1854  if (repl.empty()) {
1855  // This result was dropped and no replacement value was provided.
1856  // Materialize a replacement value "out of thin air".
1858  MaterializationKind::Source, computeInsertPoint(result),
1859  result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
1860  /*outputTypes=*/result.getType(), /*originalType=*/Type(),
1861  currentTypeConverter, /*isPureTypeConversion=*/false);
1862  continue;
1863  }
1864 
1865  // Remap result to replacement value.
1866  if (repl.empty())
1867  continue;
1868  mapping.map(static_cast<Value>(result), std::move(repl));
1869  }
1870 
1871  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1872  // Mark this operation and all nested ops as replaced.
1873  op->walk([&](Operation *op) { replacedOps.insert(op); });
1874 }
1875 
1877  BlockArgument from, ValueRange to, const TypeConverter *converter) {
1879  SmallVector<Value> toConv = llvm::to_vector(to);
1880  SmallVector<Value> repls =
1881  getReplacementValues(*this, from, {toConv}, converter);
1882  IRRewriter r(from.getContext());
1883  Value repl = repls.front();
1884  if (!repl)
1885  return;
1886 
1887  performReplaceBlockArg(r, from, repl);
1888  return;
1889  }
1890 
1891 #ifndef NDEBUG
1892  // Make sure that a block argument is not replaced multiple times. In
1893  // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
1894  // uses of the given block argument, but also all future uses that may be
1895  // introduced by future pattern applications. Therefore, it does not make
1896  // sense to call `replaceUsesOfBlockArgument` multiple times with the same
1897  // block argument. Doing so would overwrite the mapping and mess with the
1898  // internal state of the dialect conversion driver.
1899  assert(!replacedArgs.contains(from) &&
1900  "attempting to replace a block argument that was already replaced");
1901  replacedArgs.insert(from);
1902 #endif // NDEBUG
1903 
1904  appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
1905  mapping.map(from, to);
1906 }
1907 
1910  // Pattern rollback is not allowed: materialize all IR changes immediately.
1911  // Update internal data structures, so that there are no dangling pointers
1912  // to erased IR.
1913  block->walk([&](Operation *op) {
1914  erasedOps.insert(op);
1915  ignoredOps.remove(op);
1916  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1917  unresolvedMaterializations.erase(castOp);
1918  patternMaterializations.erase(castOp);
1919  }
1920  // The original op will be erased, so remove it from the set of
1921  // unlegalized ops.
1922  if (config.unlegalizedOps)
1923  config.unlegalizedOps->erase(op);
1924  });
1925  block->walk([&](Block *block) { erasedBlocks.insert(block); });
1926  // Erase the block and notify the listener.
1928  return;
1929  }
1930 
1931  assert(!wasOpReplaced(block->getParentOp()) &&
1932  "attempting to erase a block within a replaced/erased op");
1933  appendRewrite<EraseBlockRewrite>(block);
1934 
1935  // Unlink the block from its parent region. The block is kept in the rewrite
1936  // object and will be actually destroyed when rewrites are applied. This
1937  // allows us to keep the operations in the block live and undo the removal by
1938  // re-inserting the block.
1939  block->getParent()->getBlocks().remove(block);
1940 
1941  // Mark all nested ops as erased.
1942  block->walk([&](Operation *op) { replacedOps.insert(op); });
1943 }
1944 
1946  Block *block, Region *previous, Region::iterator previousIt) {
1947  // If no previous insertion point is provided, the block used to be detached.
1948  bool wasDetached = !previous;
1949  Operation *newParentOp = block->getParentOp();
1950  LLVM_DEBUG(
1951  {
1952  Operation *parent = newParentOp;
1953  if (parent) {
1954  logger.startLine() << "** Insert Block into : '" << parent->getName()
1955  << "' (" << parent << ")";
1956  } else {
1957  logger.startLine()
1958  << "** Insert Block into detached Region (nullptr parent op)";
1959  }
1960  if (wasDetached)
1961  logger.getOStream() << " (was detached)";
1962  logger.getOStream() << "\n";
1963  });
1964 
1965  // In rollback mode, it is easier to misuse the API, so perform extra error
1966  // checking.
1967  assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
1968  "attempting to insert into a region within a replaced/erased op");
1969  (void)newParentOp;
1970 
1971  // In "no rollback" mode, the listener is always notified immediately.
1973  config.listener->notifyBlockInserted(block, previous, previousIt);
1974 
1975  patternInsertedBlocks.insert(block);
1976 
1977  if (wasDetached) {
1978  // If the block was detached, it is most likely a newly created block.
1980  // TODO: If the same block is inserted multiple times from a detached
1981  // state, the rollback mechanism may erase the same block multiple times.
1982  // This is a bug in the rollback-based dialect conversion driver.
1983  appendRewrite<CreateBlockRewrite>(block);
1984  } else {
1985  // In "no rollback" mode, there is an extra data structure for tracking
1986  // erased blocks that must be kept up to date.
1987  erasedBlocks.erase(block);
1988  }
1989  return;
1990  }
1991 
1992  // The block was moved from one place to another.
1994  appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
1995 }
1996 
1998  Block *dest,
1999  Block::iterator before) {
2000  appendRewrite<InlineBlockRewrite>(dest, source, before);
2001 }
2002 
2004  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
2005  LLVM_DEBUG({
2007  reasonCallback(diag);
2008  logger.startLine() << "** Failure : " << diag.str() << "\n";
2009  if (config.notifyCallback)
2011  });
2012 }
2013 
2014 //===----------------------------------------------------------------------===//
2015 // ConversionPatternRewriter
2016 //===----------------------------------------------------------------------===//
2017 
2018 ConversionPatternRewriter::ConversionPatternRewriter(
2019  MLIRContext *ctx, const ConversionConfig &config)
2020  : PatternRewriter(ctx),
2021  impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
2022  setListener(impl.get());
2023 }
2024 
2026 
2028  return impl->config;
2029 }
2030 
2032  assert(op && newOp && "expected non-null op");
2033  replaceOp(op, newOp->getResults());
2034 }
2035 
2037  assert(op->getNumResults() == newValues.size() &&
2038  "incorrect # of replacement values");
2039  LLVM_DEBUG({
2040  impl->logger.startLine()
2041  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
2042  });
2043 
2044  // If the current insertion point is before the erased operation, we adjust
2045  // the insertion point to be after the operation.
2046  if (getInsertionPoint() == op->getIterator())
2048 
2050  llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2051  return v ? SmallVector<Value>{v} : SmallVector<Value>();
2052  });
2053  impl->replaceOp(op, std::move(newVals));
2054 }
2055 
2057  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2058  assert(op->getNumResults() == newValues.size() &&
2059  "incorrect # of replacement values");
2060  LLVM_DEBUG({
2061  impl->logger.startLine()
2062  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
2063  });
2064 
2065  // If the current insertion point is before the erased operation, we adjust
2066  // the insertion point to be after the operation.
2067  if (getInsertionPoint() == op->getIterator())
2069 
2070  impl->replaceOp(op, std::move(newValues));
2071 }
2072 
2074  LLVM_DEBUG({
2075  impl->logger.startLine()
2076  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2077  });
2078 
2079  // If the current insertion point is before the erased operation, we adjust
2080  // the insertion point to be after the operation.
2081  if (getInsertionPoint() == op->getIterator())
2083 
2084  SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2085  impl->replaceOp(op, std::move(nullRepls));
2086 }
2087 
2089  impl->eraseBlock(block);
2090 }
2091 
2093  Block *block, TypeConverter::SignatureConversion &conversion,
2094  const TypeConverter *converter) {
2095  assert(!impl->wasOpReplaced(block->getParentOp()) &&
2096  "attempting to apply a signature conversion to a block within a "
2097  "replaced/erased op");
2098  return impl->applySignatureConversion(block, converter, conversion);
2099 }
2100 
2102  Region *region, const TypeConverter &converter,
2103  TypeConverter::SignatureConversion *entryConversion) {
2104  assert(!impl->wasOpReplaced(region->getParentOp()) &&
2105  "attempting to apply a signature conversion to a block within a "
2106  "replaced/erased op");
2107  return impl->convertRegionTypes(region, converter, entryConversion);
2108 }
2109 
2111  ValueRange to) {
2112  LLVM_DEBUG({
2113  impl->logger.startLine() << "** Replace Argument : '" << from << "'";
2114  if (Operation *parentOp = from.getOwner()->getParentOp()) {
2115  impl->logger.getOStream() << " (in region of '" << parentOp->getName()
2116  << "' (" << parentOp << ")\n";
2117  } else {
2118  impl->logger.getOStream() << " (unlinked block)\n";
2119  }
2120  });
2121  impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
2122 }
2123 
2125  SmallVector<ValueVector> remappedValues;
2126  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2127  remappedValues)))
2128  return nullptr;
2129  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2130  return remappedValues.front().front();
2131 }
2132 
2133 LogicalResult
2135  SmallVectorImpl<Value> &results) {
2136  if (keys.empty())
2137  return success();
2138  SmallVector<ValueVector> remapped;
2139  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2140  remapped)))
2141  return failure();
2142  for (const auto &values : remapped) {
2143  assert(values.size() == 1 && "1:N conversion not supported");
2144  results.push_back(values.front());
2145  }
2146  return success();
2147 }
2148 
2150  Block::iterator before,
2151  ValueRange argValues) {
2152 #ifndef NDEBUG
2153  assert(argValues.size() == source->getNumArguments() &&
2154  "incorrect # of argument replacement values");
2155  assert(!impl->wasOpReplaced(source->getParentOp()) &&
2156  "attempting to inline a block from a replaced/erased op");
2157  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2158  "attempting to inline a block into a replaced/erased op");
2159  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2160  // The source block will be deleted, so it should not have any users (i.e.,
2161  // there should be no predecessors).
2162  assert(llvm::all_of(source->getUsers(), opIgnored) &&
2163  "expected 'source' to have no predecessors");
2164 #endif // NDEBUG
2165 
2166  // If a listener is attached to the dialect conversion, ops cannot be moved
2167  // to the destination block in bulk ("fast path"). This is because at the time
2168  // the notifications are sent, it is unknown which ops were moved. Instead,
2169  // ops should be moved one-by-one ("slow path"), so that a separate
2170  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2171  // a bit more efficient, so we try to do that when possible.
2172  bool fastPath = !getConfig().listener;
2173 
2174  if (fastPath && impl->config.allowPatternRollback)
2175  impl->inlineBlockBefore(source, dest, before);
2176 
2177  // Replace all uses of block arguments.
2178  for (auto it : llvm::zip(source->getArguments(), argValues))
2179  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
2180 
2181  if (fastPath) {
2182  // Move all ops at once.
2183  dest->getOperations().splice(before, source->getOperations());
2184  } else {
2185  // Move op by op.
2186  while (!source->empty())
2187  moveOpBefore(&source->front(), dest, before);
2188  }
2189 
2190  // If the current insertion point is within the source block, adjust the
2191  // insertion point to the destination block.
2192  if (getInsertionBlock() == source)
2193  setInsertionPoint(dest, getInsertionPoint());
2194 
2195  // Erase the source block.
2196  eraseBlock(source);
2197 }
2198 
2200  if (!impl->config.allowPatternRollback) {
2201  // Pattern rollback is not allowed: no extra bookkeeping is needed.
2203  return;
2204  }
2205  assert(!impl->wasOpReplaced(op) &&
2206  "attempting to modify a replaced/erased op");
2207 #ifndef NDEBUG
2208  impl->pendingRootUpdates.insert(op);
2209 #endif
2210  impl->appendRewrite<ModifyOperationRewrite>(op);
2211 }
2212 
2214  impl->patternModifiedOps.insert(op);
2215  if (!impl->config.allowPatternRollback) {
2217  if (getConfig().listener)
2218  getConfig().listener->notifyOperationModified(op);
2219  return;
2220  }
2221 
2222  // There is nothing to do here, we only need to track the operation at the
2223  // start of the update.
2224 #ifndef NDEBUG
2225  assert(!impl->wasOpReplaced(op) &&
2226  "attempting to modify a replaced/erased op");
2227  assert(impl->pendingRootUpdates.erase(op) &&
2228  "operation did not have a pending in-place update");
2229 #endif
2230 }
2231 
2233  if (!impl->config.allowPatternRollback) {
2235  return;
2236  }
2237 #ifndef NDEBUG
2238  assert(impl->pendingRootUpdates.erase(op) &&
2239  "operation did not have a pending in-place update");
2240 #endif
2241  // Erase the last update for this operation.
2242  auto it = llvm::find_if(
2243  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2244  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2245  return modifyRewrite && modifyRewrite->getOperation() == op;
2246  });
2247  assert(it != impl->rewrites.rend() && "no root update started on op");
2248  (*it)->rollback();
2249  int updateIdx = std::prev(impl->rewrites.rend()) - it;
2250  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2251 }
2252 
2254  return *impl;
2255 }
2256 
2257 //===----------------------------------------------------------------------===//
2258 // ConversionPattern
2259 //===----------------------------------------------------------------------===//
2260 
2262  ArrayRef<ValueRange> operands) const {
2263  SmallVector<Value> oneToOneOperands;
2264  oneToOneOperands.reserve(operands.size());
2265  for (ValueRange operand : operands) {
2266  if (operand.size() != 1)
2267  return failure();
2268 
2269  oneToOneOperands.push_back(operand.front());
2270  }
2271  return std::move(oneToOneOperands);
2272 }
2273 
2274 LogicalResult
2276  PatternRewriter &rewriter) const {
2277  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2278  auto &rewriterImpl = dialectRewriter.getImpl();
2279 
2280  // Track the current conversion pattern type converter in the rewriter.
2281  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2282  getTypeConverter());
2283 
2284  // Remap the operands of the operation.
2285  SmallVector<ValueVector> remapped;
2286  if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2287  op->getOperands(), remapped))) {
2288  return failure();
2289  }
2290  SmallVector<ValueRange> remappedAsRange =
2291  llvm::to_vector_of<ValueRange>(remapped);
2292  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2293 }
2294 
2295 //===----------------------------------------------------------------------===//
2296 // OperationLegalizer
2297 //===----------------------------------------------------------------------===//
2298 
2299 namespace {
2300 /// A set of rewrite patterns that can be used to legalize a given operation.
2301 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2302 
2303 /// This class defines a recursive operation legalizer.
2304 class OperationLegalizer {
2305 public:
2306  using LegalizationAction = ConversionTarget::LegalizationAction;
2307 
2308  OperationLegalizer(ConversionPatternRewriter &rewriter,
2309  const ConversionTarget &targetInfo,
2311 
2312  /// Returns true if the given operation is known to be illegal on the target.
2313  bool isIllegal(Operation *op) const;
2314 
2315  /// Attempt to legalize the given operation. Returns success if the operation
2316  /// was legalized, failure otherwise.
2317  LogicalResult legalize(Operation *op);
2318 
2319  /// Returns the conversion target in use by the legalizer.
2320  const ConversionTarget &getTarget() { return target; }
2321 
2322 private:
2323  /// Attempt to legalize the given operation by folding it.
2324  LogicalResult legalizeWithFold(Operation *op);
2325 
2326  /// Attempt to legalize the given operation by applying a pattern. Returns
2327  /// success if the operation was legalized, failure otherwise.
2328  LogicalResult legalizeWithPattern(Operation *op);
2329 
2330  /// Return true if the given pattern may be applied to the given operation,
2331  /// false otherwise.
2332  bool canApplyPattern(Operation *op, const Pattern &pattern);
2333 
2334  /// Legalize the resultant IR after successfully applying the given pattern.
2335  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
2336  const RewriterState &curState,
2337  const SetVector<Operation *> &newOps,
2338  const SetVector<Operation *> &modifiedOps,
2339  const SetVector<Block *> &insertedBlocks);
2340 
2341  /// Legalizes the actions registered during the execution of a pattern.
2342  LogicalResult
2343  legalizePatternBlockRewrites(Operation *op,
2344  const SetVector<Block *> &insertedBlocks,
2345  const SetVector<Operation *> &newOps);
2346  LogicalResult
2347  legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
2348  LogicalResult
2349  legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2350 
2351  //===--------------------------------------------------------------------===//
2352  // Cost Model
2353  //===--------------------------------------------------------------------===//
2354 
2355  /// Build an optimistic legalization graph given the provided patterns. This
2356  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2357  /// patterns for operations that are not directly legal, but may be
2358  /// transitively legal for the current target given the provided patterns.
2359  void buildLegalizationGraph(
2360  LegalizationPatterns &anyOpLegalizerPatterns,
2362 
2363  /// Compute the benefit of each node within the computed legalization graph.
2364  /// This orders the patterns within 'legalizerPatterns' based upon two
2365  /// criteria:
2366  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2367  /// represent the more direct mapping to the target.
2368  /// 2) When comparing patterns with the same legalization depth, prefer the
2369  /// pattern with the highest PatternBenefit. This allows for users to
2370  /// prefer specific legalizations over others.
2371  void computeLegalizationGraphBenefit(
2372  LegalizationPatterns &anyOpLegalizerPatterns,
2374 
2375  /// Compute the legalization depth when legalizing an operation of the given
2376  /// type.
2377  unsigned computeOpLegalizationDepth(
2378  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2380 
2381  /// Apply the conversion cost model to the given set of patterns, and return
2382  /// the smallest legalization depth of any of the patterns. See
2383  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2384  unsigned applyCostModelToPatterns(
2385  LegalizationPatterns &patterns,
2386  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2388 
2389  /// The current set of patterns that have been applied.
2390  SmallPtrSet<const Pattern *, 8> appliedPatterns;
2391 
2392  /// The rewriter to use when converting operations.
2393  ConversionPatternRewriter &rewriter;
2394 
2395  /// The legalization information provided by the target.
2396  const ConversionTarget &target;
2397 
2398  /// The pattern applicator to use for conversions.
2399  PatternApplicator applicator;
2400 };
2401 } // namespace
2402 
2403 OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2404  const ConversionTarget &targetInfo,
2406  : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2407  // The set of patterns that can be applied to illegal operations to transform
2408  // them into legal ones.
2410  LegalizationPatterns anyOpLegalizerPatterns;
2411 
2412  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2413  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2414 }
2415 
2416 bool OperationLegalizer::isIllegal(Operation *op) const {
2417  return target.isIllegal(op);
2418 }
2419 
2420 LogicalResult OperationLegalizer::legalize(Operation *op) {
2421 #ifndef NDEBUG
2422  const char *logLineComment =
2423  "//===-------------------------------------------===//\n";
2424 
2425  auto &logger = rewriter.getImpl().logger;
2426 #endif
2427 
2428  // Check to see if the operation is ignored and doesn't need to be converted.
2429  bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2430 
2431  LLVM_DEBUG({
2432  logger.getOStream() << "\n";
2433  logger.startLine() << logLineComment;
2434  logger.startLine() << "Legalizing operation : ";
2435  // Do not print the operation name if the operation is ignored. Ignored ops
2436  // may have been erased and should not be accessed. The pointer can be
2437  // printed safely.
2438  if (!isIgnored)
2439  logger.getOStream() << "'" << op->getName() << "' ";
2440  logger.getOStream() << "(" << op << ") {\n";
2441  logger.indent();
2442 
2443  // If the operation has no regions, just print it here.
2444  if (!isIgnored && op->getNumRegions() == 0) {
2445  logger.startLine() << OpWithFlags(op,
2446  OpPrintingFlags().printGenericOpForm())
2447  << "\n";
2448  }
2449  });
2450 
2451  if (isIgnored) {
2452  LLVM_DEBUG({
2453  logSuccess(logger, "operation marked 'ignored' during conversion");
2454  logger.startLine() << logLineComment;
2455  });
2456  return success();
2457  }
2458 
2459  // Check if this operation is legal on the target.
2460  if (auto legalityInfo = target.isLegal(op)) {
2461  LLVM_DEBUG({
2462  logSuccess(
2463  logger, "operation marked legal by the target{0}",
2464  legalityInfo->isRecursivelyLegal
2465  ? "; NOTE: operation is recursively legal; skipping internals"
2466  : "");
2467  logger.startLine() << logLineComment;
2468  });
2469 
2470  // If this operation is recursively legal, mark its children as ignored so
2471  // that we don't consider them for legalization.
2472  if (legalityInfo->isRecursivelyLegal) {
2473  op->walk([&](Operation *nested) {
2474  if (op != nested)
2475  rewriter.getImpl().ignoredOps.insert(nested);
2476  });
2477  }
2478 
2479  return success();
2480  }
2481 
2482  // If the operation is not legal, try to fold it in-place if the folding mode
2483  // is 'BeforePatterns'. 'Never' will skip this.
2484  const ConversionConfig &config = rewriter.getConfig();
2486  if (succeeded(legalizeWithFold(op))) {
2487  LLVM_DEBUG({
2488  logSuccess(logger, "operation was folded");
2489  logger.startLine() << logLineComment;
2490  });
2491  return success();
2492  }
2493  }
2494 
2495  // Otherwise, we need to apply a legalization pattern to this operation.
2496  if (succeeded(legalizeWithPattern(op))) {
2497  LLVM_DEBUG({
2498  logSuccess(logger, "");
2499  logger.startLine() << logLineComment;
2500  });
2501  return success();
2502  }
2503 
2504  // If the operation can't be legalized via patterns, try to fold it in-place
2505  // if the folding mode is 'AfterPatterns'.
2507  if (succeeded(legalizeWithFold(op))) {
2508  LLVM_DEBUG({
2509  logSuccess(logger, "operation was folded");
2510  logger.startLine() << logLineComment;
2511  });
2512  return success();
2513  }
2514  }
2515 
2516  LLVM_DEBUG({
2517  logFailure(logger, "no matched legalization pattern");
2518  logger.startLine() << logLineComment;
2519  });
2520  return failure();
2521 }
2522 
/// Takes the contents of `obj` and hands them back to the caller, leaving
/// `obj` holding a value-initialized `T` so that it is again in a valid,
/// empty state and can be reused.
template <typename T>
static T moveAndReset(T &obj) {
  // std::exchange move-constructs the returned value from `obj` and then
  // move-assigns a fresh `T()` into `obj`.
  return std::exchange(obj, T());
}
2531 
2532 LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2533  auto &rewriterImpl = rewriter.getImpl();
2534  LLVM_DEBUG({
2535  rewriterImpl.logger.startLine() << "* Fold {\n";
2536  rewriterImpl.logger.indent();
2537  });
2538 
2539  // Clear pattern state, so that the next pattern application starts with a
2540  // clean slate. (The op/block sets are populated by listener notifications.)
2541  auto cleanup = llvm::make_scope_exit([&]() {
2542  rewriterImpl.patternNewOps.clear();
2543  rewriterImpl.patternModifiedOps.clear();
2544  rewriterImpl.patternInsertedBlocks.clear();
2545  });
2546 
2547  // Upon failure, undo all changes made by the folder.
2548  RewriterState curState = rewriterImpl.getCurrentState();
2549 
2550  // Try to fold the operation.
2551  StringRef opName = op->getName().getStringRef();
2552  SmallVector<Value, 2> replacementValues;
2553  SmallVector<Operation *, 2> newOps;
2554  rewriter.setInsertionPoint(op);
2555  rewriter.startOpModification(op);
2556  if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2557  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2558  rewriter.cancelOpModification(op);
2559  return failure();
2560  }
2561  rewriter.finalizeOpModification(op);
2562 
2563  // An empty list of replacement values indicates that the fold was in-place.
2564  // As the operation changed, a new legalization needs to be attempted.
2565  if (replacementValues.empty())
2566  return legalize(op);
2567 
2568  // Insert a replacement for 'op' with the folded replacement values.
2569  rewriter.replaceOp(op, replacementValues);
2570 
2571  // Recursively legalize any new constant operations.
2572  for (Operation *newOp : newOps) {
2573  if (failed(legalize(newOp))) {
2574  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2575  "failed to legalize generated constant '{0}'",
2576  newOp->getName()));
2577  if (!rewriter.getConfig().allowPatternRollback) {
2578  // Rolling back a folder is like rolling back a pattern.
2579  llvm::report_fatal_error(
2580  "op '" + opName +
2581  "' folder rollback of IR modifications requested");
2582  }
2583  rewriterImpl.resetState(
2584  curState, std::string(op->getName().getStringRef()) + " folder");
2585  return failure();
2586  }
2587  }
2588 
2589  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2590  return success();
2591 }
2592 
2593 /// Report a fatal error indicating that newly produced or modified IR could
2594 /// not be legalized.
2595 static void
2597  const SetVector<Operation *> &newOps,
2598  const SetVector<Operation *> &modifiedOps,
2599  const SetVector<Block *> &insertedBlocks) {
2600  auto newOpNames = llvm::map_range(
2601  newOps, [](Operation *op) { return op->getName().getStringRef(); });
2602  auto modifiedOpNames = llvm::map_range(
2603  modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
2604  StringRef detachedBlockStr = "(detached block)";
2605  auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
2606  if (block->getParentOp())
2607  return block->getParentOp()->getName().getStringRef();
2608  return detachedBlockStr;
2609  });
2610  llvm::report_fatal_error(
2611  "pattern '" + pattern.getDebugName() +
2612  "' produced IR that could not be legalized. " + "new ops: {" +
2613  llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
2614  llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
2615  llvm::join(insertedBlockNames, ", ") + "}");
2616 }
2617 
2618 LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2619  auto &rewriterImpl = rewriter.getImpl();
2620  const ConversionConfig &config = rewriter.getConfig();
2621 
2622 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2623  Operation *checkOp;
2624  std::optional<OperationFingerPrint> topLevelFingerPrint;
2625  if (!rewriterImpl.config.allowPatternRollback) {
2626  // The op may be getting erased, so we have to check the parent op.
2627  // (In rare cases, a pattern may even erase the parent op, which will cause
2628  // a crash here. Expensive checks are "best effort".) Skip the check if the
2629  // op does not have a parent op.
2630  if ((checkOp = op->getParentOp())) {
2631  if (!op->getContext()->isMultithreadingEnabled()) {
2632  topLevelFingerPrint = OperationFingerPrint(checkOp);
2633  } else {
2634  // Another thread may be modifying a sibling operation. Therefore, the
2635  // fingerprinting mechanism of the parent op works only in
2636  // single-threaded mode.
2637  LLVM_DEBUG({
2638  rewriterImpl.logger.startLine()
2639  << "WARNING: Multi-threadeding is enabled. Some dialect "
2640  "conversion expensive checks are skipped in multithreading "
2641  "mode!\n";
2642  });
2643  }
2644  }
2645  }
2646 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2647 
2648  // Functor that returns if the given pattern may be applied.
2649  auto canApply = [&](const Pattern &pattern) {
2650  bool canApply = canApplyPattern(op, pattern);
2651  if (canApply && config.listener)
2652  config.listener->notifyPatternBegin(pattern, op);
2653  return canApply;
2654  };
2655 
2656  // Functor that cleans up the rewriter state after a pattern failed to match.
2657  RewriterState curState = rewriterImpl.getCurrentState();
2658  auto onFailure = [&](const Pattern &pattern) {
2659  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2660  if (!rewriterImpl.config.allowPatternRollback) {
2661  // Erase all unresolved materializations.
2662  for (auto op : rewriterImpl.patternMaterializations) {
2663  rewriterImpl.unresolvedMaterializations.erase(op);
2664  op.erase();
2665  }
2666  rewriterImpl.patternMaterializations.clear();
2667 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2668  // Expensive pattern check that can detect API violations.
2669  if (checkOp) {
2670  OperationFingerPrint fingerPrintAfterPattern(checkOp);
2671  if (fingerPrintAfterPattern != *topLevelFingerPrint)
2672  llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2673  "' returned failure but IR did change");
2674  }
2675 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2676  }
2677  rewriterImpl.patternNewOps.clear();
2678  rewriterImpl.patternModifiedOps.clear();
2679  rewriterImpl.patternInsertedBlocks.clear();
2680  LLVM_DEBUG({
2681  logFailure(rewriterImpl.logger, "pattern failed to match");
2682  if (rewriterImpl.config.notifyCallback) {
2684  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2685  << "\" on op:\n"
2686  << *op;
2687  rewriterImpl.config.notifyCallback(diag);
2688  }
2689  });
2690  if (config.listener)
2691  config.listener->notifyPatternEnd(pattern, failure());
2692  rewriterImpl.resetState(curState, pattern.getDebugName());
2693  appliedPatterns.erase(&pattern);
2694  };
2695 
2696  // Functor that performs additional legalization when a pattern is
2697  // successfully applied.
2698  auto onSuccess = [&](const Pattern &pattern) {
2699  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2700  if (!rewriterImpl.config.allowPatternRollback) {
2701  // Eagerly erase unused materializations.
2702  for (auto op : rewriterImpl.patternMaterializations) {
2703  if (op->use_empty()) {
2704  rewriterImpl.unresolvedMaterializations.erase(op);
2705  op.erase();
2706  }
2707  }
2708  rewriterImpl.patternMaterializations.clear();
2709  }
2710  SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2711  SetVector<Operation *> modifiedOps =
2712  moveAndReset(rewriterImpl.patternModifiedOps);
2713  SetVector<Block *> insertedBlocks =
2714  moveAndReset(rewriterImpl.patternInsertedBlocks);
2715  auto result = legalizePatternResult(op, pattern, curState, newOps,
2716  modifiedOps, insertedBlocks);
2717  appliedPatterns.erase(&pattern);
2718  if (failed(result)) {
2719  if (!rewriterImpl.config.allowPatternRollback)
2720  reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
2721  insertedBlocks);
2722  rewriterImpl.resetState(curState, pattern.getDebugName());
2723  }
2724  if (config.listener)
2725  config.listener->notifyPatternEnd(pattern, result);
2726  return result;
2727  };
2728 
2729  // Try to match and rewrite a pattern on this operation.
2730  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2731  onSuccess);
2732 }
2733 
2734 bool OperationLegalizer::canApplyPattern(Operation *op,
2735  const Pattern &pattern) {
2736  LLVM_DEBUG({
2737  auto &os = rewriter.getImpl().logger;
2738  os.getOStream() << "\n";
2739  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2740  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2741  os.getOStream() << ")' {\n";
2742  os.indent();
2743  });
2744 
2745  // Ensure that we don't cycle by not allowing the same pattern to be
2746  // applied twice in the same recursion stack if it is not known to be safe.
2747  if (!pattern.hasBoundedRewriteRecursion() &&
2748  !appliedPatterns.insert(&pattern).second) {
2749  LLVM_DEBUG(
2750  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2751  return false;
2752  }
2753  return true;
2754 }
2755 
2756 LogicalResult OperationLegalizer::legalizePatternResult(
2757  Operation *op, const Pattern &pattern, const RewriterState &curState,
2758  const SetVector<Operation *> &newOps,
2759  const SetVector<Operation *> &modifiedOps,
2760  const SetVector<Block *> &insertedBlocks) {
2761  [[maybe_unused]] auto &impl = rewriter.getImpl();
2762  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2763 
2764 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2765  // Check that the root was either replaced or updated in place.
2766  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2767  auto replacedRoot = [&] {
2768  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2769  };
2770  auto updatedRootInPlace = [&] {
2771  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2772  };
2773  if (!replacedRoot() && !updatedRootInPlace())
2774  llvm::report_fatal_error(
2775  "expected pattern to replace the root operation or modify it in place");
2776 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2777 
2778  // Legalize each of the actions registered during application.
2779  if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
2780  failed(legalizePatternRootUpdates(modifiedOps)) ||
2781  failed(legalizePatternCreatedOperations(newOps))) {
2782  return failure();
2783  }
2784 
2785  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2786  return success();
2787 }
2788 
2789 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2790  Operation *op, const SetVector<Block *> &insertedBlocks,
2791  const SetVector<Operation *> &newOps) {
2792  ConversionPatternRewriterImpl &impl = rewriter.getImpl();
2793  SmallPtrSet<Operation *, 16> alreadyLegalized;
2794 
2795  // If the pattern moved or created any blocks, make sure the types of block
2796  // arguments get legalized.
2797  for (Block *block : insertedBlocks) {
2798  if (impl.erasedBlocks.contains(block))
2799  continue;
2800 
2801  // Only check blocks outside of the current operation.
2802  Operation *parentOp = block->getParentOp();
2803  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2804  continue;
2805 
2806  // If the region of the block has a type converter, try to convert the block
2807  // directly.
2808  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2809  std::optional<TypeConverter::SignatureConversion> conversion =
2810  converter->convertBlockSignature(block);
2811  if (!conversion) {
2812  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2813  "block"));
2814  return failure();
2815  }
2816  impl.applySignatureConversion(block, converter, *conversion);
2817  continue;
2818  }
2819 
2820  // Otherwise, try to legalize the parent operation if it was not generated
2821  // by this pattern. This is because we will attempt to legalize the parent
2822  // operation, and blocks in regions created by this pattern will already be
2823  // legalized later on.
2824  if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
2825  if (failed(legalize(parentOp))) {
2826  LLVM_DEBUG(logFailure(
2827  impl.logger, "operation '{0}'({1}) became illegal after rewrite",
2828  parentOp->getName(), parentOp));
2829  return failure();
2830  }
2831  }
2832  }
2833  return success();
2834 }
2835 
2836 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2837  const SetVector<Operation *> &newOps) {
2838  for (Operation *op : newOps) {
2839  if (failed(legalize(op))) {
2840  LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2841  "failed to legalize generated operation '{0}'({1})",
2842  op->getName(), op));
2843  return failure();
2844  }
2845  }
2846  return success();
2847 }
2848 
2849 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2850  const SetVector<Operation *> &modifiedOps) {
2851  for (Operation *op : modifiedOps) {
2852  if (failed(legalize(op))) {
2853  LLVM_DEBUG(
2854  logFailure(rewriter.getImpl().logger,
2855  "failed to legalize operation updated in-place '{0}'",
2856  op->getName()));
2857  return failure();
2858  }
2859  }
2860  return success();
2861 }
2862 
2863 //===----------------------------------------------------------------------===//
2864 // Cost Model
2865 //===----------------------------------------------------------------------===//
2866 
2867 void OperationLegalizer::buildLegalizationGraph(
2868  LegalizationPatterns &anyOpLegalizerPatterns,
2869  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2870  // A mapping between an operation and a set of operations that can be used to
2871  // generate it.
2873  // A mapping between an operation and any currently invalid patterns it has.
2875  // A worklist of patterns to consider for legality.
2876  SetVector<const Pattern *> patternWorklist;
2877 
2878  // Build the mapping from operations to the parent ops that may generate them.
2879  applicator.walkAllPatterns([&](const Pattern &pattern) {
2880  std::optional<OperationName> root = pattern.getRootKind();
2881 
2882  // If the pattern has no specific root, we can't analyze the relationship
2883  // between the root op and generated operations. Given that, add all such
2884  // patterns to the legalization set.
2885  if (!root) {
2886  anyOpLegalizerPatterns.push_back(&pattern);
2887  return;
2888  }
2889 
2890  // Skip operations that are always known to be legal.
2891  if (target.getOpAction(*root) == LegalizationAction::Legal)
2892  return;
2893 
2894  // Add this pattern to the invalid set for the root op and record this root
2895  // as a parent for any generated operations.
2896  invalidPatterns[*root].insert(&pattern);
2897  for (auto op : pattern.getGeneratedOps())
2898  parentOps[op].insert(*root);
2899 
2900  // Add this pattern to the worklist.
2901  patternWorklist.insert(&pattern);
2902  });
2903 
2904  // If there are any patterns that don't have a specific root kind, we can't
2905  // make direct assumptions about what operations will never be legalized.
2906  // Note: Technically we could, but it would require an analysis that may
2907  // recurse into itself. It would be better to perform this kind of filtering
2908  // at a higher level than here anyways.
2909  if (!anyOpLegalizerPatterns.empty()) {
2910  for (const Pattern *pattern : patternWorklist)
2911  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2912  return;
2913  }
2914 
2915  while (!patternWorklist.empty()) {
2916  auto *pattern = patternWorklist.pop_back_val();
2917 
2918  // Check to see if any of the generated operations are invalid.
2919  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2920  std::optional<LegalizationAction> action = target.getOpAction(op);
2921  return !legalizerPatterns.count(op) &&
2922  (!action || action == LegalizationAction::Illegal);
2923  }))
2924  continue;
2925 
2926  // Otherwise, if all of the generated operation are valid, this op is now
2927  // legal so add all of the child patterns to the worklist.
2928  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2929  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2930 
2931  // Add any invalid patterns of the parent operations to see if they have now
2932  // become legal.
2933  for (auto op : parentOps[*pattern->getRootKind()])
2934  patternWorklist.set_union(invalidPatterns[op]);
2935  }
2936 }
2937 
2938 void OperationLegalizer::computeLegalizationGraphBenefit(
2939  LegalizationPatterns &anyOpLegalizerPatterns,
2940  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2941  // The smallest pattern depth, when legalizing an operation.
2942  DenseMap<OperationName, unsigned> minOpPatternDepth;
2943 
2944  // For each operation that is transitively legal, compute a cost for it.
2945  for (auto &opIt : legalizerPatterns)
2946  if (!minOpPatternDepth.count(opIt.first))
2947  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2948  legalizerPatterns);
2949 
2950  // Apply the cost model to the patterns that can match any operation. Those
2951  // with a specific operation type are already resolved when computing the op
2952  // legalization depth.
2953  if (!anyOpLegalizerPatterns.empty())
2954  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2955  legalizerPatterns);
2956 
2957  // Apply a cost model to the pattern applicator. We order patterns first by
2958  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2959  // decreasing benefit.
2960  applicator.applyCostModel([&](const Pattern &pattern) {
2961  ArrayRef<const Pattern *> orderedPatternList;
2962  if (std::optional<OperationName> rootName = pattern.getRootKind())
2963  orderedPatternList = legalizerPatterns[*rootName];
2964  else
2965  orderedPatternList = anyOpLegalizerPatterns;
2966 
2967  // If the pattern is not found, then it was removed and cannot be matched.
2968  auto *it = llvm::find(orderedPatternList, &pattern);
2969  if (it == orderedPatternList.end())
2971 
2972  // Patterns found earlier in the list have higher benefit.
2973  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2974  });
2975 }
2976 
2977 unsigned OperationLegalizer::computeOpLegalizationDepth(
2978  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2979  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2980  // Check for existing depth.
2981  auto depthIt = minOpPatternDepth.find(op);
2982  if (depthIt != minOpPatternDepth.end())
2983  return depthIt->second;
2984 
2985  // If a mapping for this operation does not exist, then this operation
2986  // is always legal. Return 0 as the depth for a directly legal operation.
2987  auto opPatternsIt = legalizerPatterns.find(op);
2988  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2989  return 0u;
2990 
2991  // Record this initial depth in case we encounter this op again when
2992  // recursively computing the depth.
2993  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2994 
2995  // Apply the cost model to the operation patterns, and update the minimum
2996  // depth.
2997  unsigned minDepth = applyCostModelToPatterns(
2998  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2999  minOpPatternDepth[op] = minDepth;
3000  return minDepth;
3001 }
3002 
3003 unsigned OperationLegalizer::applyCostModelToPatterns(
3004  LegalizationPatterns &patterns,
3005  DenseMap<OperationName, unsigned> &minOpPatternDepth,
3006  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
3007  unsigned minDepth = std::numeric_limits<unsigned>::max();
3008 
3009  // Compute the depth for each pattern within the set.
3010  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3011  patternsByDepth.reserve(patterns.size());
3012  for (const Pattern *pattern : patterns) {
3013  unsigned depth = 1;
3014  for (auto generatedOp : pattern->getGeneratedOps()) {
3015  unsigned generatedOpDepth = computeOpLegalizationDepth(
3016  generatedOp, minOpPatternDepth, legalizerPatterns);
3017  depth = std::max(depth, generatedOpDepth + 1);
3018  }
3019  patternsByDepth.emplace_back(pattern, depth);
3020 
3021  // Update the minimum depth of the pattern list.
3022  minDepth = std::min(minDepth, depth);
3023  }
3024 
3025  // If the operation only has one legalization pattern, there is no need to
3026  // sort them.
3027  if (patternsByDepth.size() == 1)
3028  return minDepth;
3029 
3030  // Sort the patterns by those likely to be the most beneficial.
3031  llvm::stable_sort(patternsByDepth,
3032  [](const std::pair<const Pattern *, unsigned> &lhs,
3033  const std::pair<const Pattern *, unsigned> &rhs) {
3034  // First sort by the smaller pattern legalization
3035  // depth.
3036  if (lhs.second != rhs.second)
3037  return lhs.second < rhs.second;
3038 
3039  // Then sort by the larger pattern benefit.
3040  auto lhsBenefit = lhs.first->getBenefit();
3041  auto rhsBenefit = rhs.first->getBenefit();
3042  return lhsBenefit > rhsBenefit;
3043  });
3044 
3045  // Update the legalization pattern to use the new sorted list.
3046  patterns.clear();
3047  for (auto &patternIt : patternsByDepth)
3048  patterns.push_back(patternIt.first);
3049  return minDepth;
3050 }
3051 
3052 //===----------------------------------------------------------------------===//
3053 // OperationConverter
3054 //===----------------------------------------------------------------------===//
namespace {
/// Selects how OperationConverter treats operations that fail to legalize.
enum OpConversionMode {
  /// Failed conversions are tolerated, so illegal operations may remain in
  /// the IR alongside converted ones.
  Partial,

  /// Every operation must become legal for the target, or the whole
  /// conversion fails.
  Full,

  /// Operations are only analyzed for legality; even on success, no actual
  /// rewrites are applied to them.
  Analysis,
};
} // namespace
3070 
3071 namespace mlir {
3072 // This class converts operations to a given conversion target via a set of
3073 // rewrite patterns. The conversion behaves differently depending on the
3074 // conversion mode.
3076  explicit OperationConverter(MLIRContext *ctx, const ConversionTarget &target,
3078  const ConversionConfig &config,
3079  OpConversionMode mode)
3080  : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
3081  mode(mode) {}
3082 
3083  /// Converts the given operations to the conversion target.
3084  LogicalResult convertOperations(ArrayRef<Operation *> ops);
3085 
3086 private:
3087  /// Converts an operation with the given rewriter.
3088  LogicalResult convert(Operation *op);
3089 
3090  /// The rewriter to use when converting operations.
3091  ConversionPatternRewriter rewriter;
3092 
3093  /// The legalizer to use when converting operations.
3094  OperationLegalizer opLegalizer;
3095 
3096  /// The conversion mode to use when legalizing operations.
3097  OpConversionMode mode;
3098 };
3099 } // namespace mlir
3100 
3101 LogicalResult OperationConverter::convert(Operation *op) {
3102  const ConversionConfig &config = rewriter.getConfig();
3103 
3104  // Legalize the given operation.
3105  if (failed(opLegalizer.legalize(op))) {
3106  // Handle the case of a failed conversion for each of the different modes.
3107  // Full conversions expect all operations to be converted.
3108  if (mode == OpConversionMode::Full)
3109  return op->emitError()
3110  << "failed to legalize operation '" << op->getName() << "'";
3111  // Partial conversions allow conversions to fail iff the operation was not
3112  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
3113  // set, non-legalizable ops are added to that set.
3114  if (mode == OpConversionMode::Partial) {
3115  if (opLegalizer.isIllegal(op))
3116  return op->emitError()
3117  << "failed to legalize operation '" << op->getName()
3118  << "' that was explicitly marked illegal";
3119  if (config.unlegalizedOps)
3120  config.unlegalizedOps->insert(op);
3121  }
3122  } else if (mode == OpConversionMode::Analysis) {
3123  // Analysis conversions don't fail if any operations fail to legalize,
3124  // they are only interested in the operations that were successfully
3125  // legalized.
3126  if (config.legalizableOps)
3127  config.legalizableOps->insert(op);
3128  }
3129  return success();
3130 }
3131 
3132 static LogicalResult
3134  UnrealizedConversionCastOp op,
3135  const UnresolvedMaterializationInfo &info) {
3136  assert(!op.use_empty() &&
3137  "expected that dead materializations have already been DCE'd");
3138  Operation::operand_range inputOperands = op.getOperands();
3139 
3140  // Try to materialize the conversion.
3141  if (const TypeConverter *converter = info.getConverter()) {
3142  rewriter.setInsertionPoint(op);
3143  SmallVector<Value> newMaterialization;
3144  switch (info.getMaterializationKind()) {
3145  case MaterializationKind::Target:
3146  newMaterialization = converter->materializeTargetConversion(
3147  rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3148  info.getOriginalType());
3149  break;
3150  case MaterializationKind::Source:
3151  assert(op->getNumResults() == 1 && "expected single result");
3152  Value sourceMat = converter->materializeSourceConversion(
3153  rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3154  if (sourceMat)
3155  newMaterialization.push_back(sourceMat);
3156  break;
3157  }
3158  if (!newMaterialization.empty()) {
3159 #ifndef NDEBUG
3160  ValueRange newMaterializationRange(newMaterialization);
3161  assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
3162  "materialization callback produced value of incorrect type");
3163 #endif // NDEBUG
3164  rewriter.replaceOp(op, newMaterialization);
3165  return success();
3166  }
3167  }
3168 
3169  InFlightDiagnostic diag = op->emitError()
3170  << "failed to legalize unresolved materialization "
3171  "from ("
3172  << inputOperands.getTypes() << ") to ("
3173  << op.getResultTypes()
3174  << ") that remained live after conversion";
3175  diag.attachNote(op->getUsers().begin()->getLoc())
3176  << "see existing live user here: " << *op->getUsers().begin();
3177  return failure();
3178 }
3179 
3181  const ConversionTarget &target = opLegalizer.getTarget();
3182 
3183  // Compute the set of operations and blocks to convert.
3184  SmallVector<Operation *> toConvert;
3185  for (auto *op : ops) {
3187  [&](Operation *op) {
3188  toConvert.push_back(op);
3189  // Don't check this operation's children for conversion if the
3190  // operation is recursively legal.
3191  auto legalityInfo = target.isLegal(op);
3192  if (legalityInfo && legalityInfo->isRecursivelyLegal)
3193  return WalkResult::skip();
3194  return WalkResult::advance();
3195  });
3196  }
3197 
3198  // Convert each operation and discard rewrites on failure.
3199  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3200 
3201  for (auto *op : toConvert) {
3202  if (failed(convert(op))) {
3203  // Dialect conversion failed.
3204  if (rewriterImpl.config.allowPatternRollback) {
3205  // Rollback is allowed: restore the original IR.
3206  rewriterImpl.undoRewrites();
3207  } else {
3208  // Rollback is not allowed: apply all modifications that have been
3209  // performed so far.
3210  rewriterImpl.applyRewrites();
3211  }
3212  return failure();
3213  }
3214  }
3215 
3216  // After a successful conversion, apply rewrites.
3217  rewriterImpl.applyRewrites();
3218 
3219  // Gather all unresolved materializations.
3222  &materializations = rewriterImpl.unresolvedMaterializations;
3223  for (auto it : materializations)
3224  allCastOps.push_back(it.first);
3225 
3226  // Reconcile all UnrealizedConversionCastOps that were inserted by the
3227  // dialect conversion frameworks. (Not the one that were inserted by
3228  // patterns.)
3229  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
3230  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
3231 
3232  // Drop markers.
3233  for (UnrealizedConversionCastOp castOp : remainingCastOps)
3234  castOp->removeAttr(kPureTypeConversionMarker);
3235 
3236  // Try to legalize all unresolved materializations.
3237  if (rewriter.getConfig().buildMaterializations) {
3238  // Use a new rewriter, so the modifications are not tracked for rollback
3239  // purposes etc.
3240  IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3241  rewriter.getConfig().listener);
3242  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3243  auto it = materializations.find(castOp);
3244  assert(it != materializations.end() && "inconsistent state");
3245  if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3246  it->second)))
3247  return failure();
3248  }
3249  }
3250 
3251  return success();
3252 }
3253 
3254 //===----------------------------------------------------------------------===//
3255 // Reconcile Unrealized Casts
3256 //===----------------------------------------------------------------------===//
3257 
3260  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3261  SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
3262  // This set is maintained only if `remainingCastOps` is provided.
3263  DenseSet<Operation *> erasedOps;
3264 
3265  // Helper function that adds all operands to the worklist that are an
3266  // unrealized_conversion_cast op result.
3267  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
3268  for (Value v : castOp.getInputs())
3269  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3270  worklist.insert(inputCastOp);
3271  };
3272 
3273  // Helper function that return the unrealized_conversion_cast op that
3274  // defines all inputs of the given op (in the same order). Return "nullptr"
3275  // if there is no such op.
3276  auto getInputCast =
3277  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3278  if (castOp.getInputs().empty())
3279  return {};
3280  auto inputCastOp =
3281  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3282  if (!inputCastOp)
3283  return {};
3284  if (inputCastOp.getOutputs() != castOp.getInputs())
3285  return {};
3286  return inputCastOp;
3287  };
3288 
3289  // Process ops in the worklist bottom-to-top.
3290  while (!worklist.empty()) {
3291  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3292  if (castOp->use_empty()) {
3293  // DCE: If the op has no users, erase it. Add the operands to the
3294  // worklist to find additional DCE opportunities.
3295  enqueueOperands(castOp);
3296  if (remainingCastOps)
3297  erasedOps.insert(castOp.getOperation());
3298  castOp->erase();
3299  continue;
3300  }
3301 
3302  // Traverse the chain of input cast ops to see if an op with the same
3303  // input types can be found.
3304  UnrealizedConversionCastOp nextCast = castOp;
3305  while (nextCast) {
3306  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3307  // Found a cast where the input types match the output types of the
3308  // matched op. We can directly use those inputs and the matched op can
3309  // be removed.
3310  enqueueOperands(castOp);
3311  castOp.replaceAllUsesWith(nextCast.getInputs());
3312  if (remainingCastOps)
3313  erasedOps.insert(castOp.getOperation());
3314  castOp->erase();
3315  break;
3316  }
3317  nextCast = getInputCast(nextCast);
3318  }
3319  }
3320 
3321  if (remainingCastOps)
3322  for (UnrealizedConversionCastOp op : castOps)
3323  if (!erasedOps.contains(op.getOperation()))
3324  remainingCastOps->push_back(op);
3325 }
3326 
3327 //===----------------------------------------------------------------------===//
3328 // Type Conversion
3329 //===----------------------------------------------------------------------===//
3330 
3332  ArrayRef<Type> types) {
3333  assert(!types.empty() && "expected valid types");
3334  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3335  addInputs(types);
3336 }
3337 
3339  assert(!types.empty() &&
3340  "1->0 type remappings don't need to be added explicitly");
3341  argTypes.append(types.begin(), types.end());
3342 }
3343 
3344 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3345  unsigned newInputNo,
3346  unsigned newInputCount) {
3347  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3348  assert(newInputCount != 0 && "expected valid input count");
3349  remappedInputs[origInputNo] =
3350  InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3351 }
3352 
3354  unsigned origInputNo, ArrayRef<Value> replacements) {
3355  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3356  remappedInputs[origInputNo] = InputMapping{
3357  origInputNo, /*size=*/0,
3358  SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3359 }
3360 
3362  SmallVectorImpl<Type> &results) const {
3363  assert(t && "expected non-null type");
3364 
3365  {
3366  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3367  std::defer_lock);
3369  cacheReadLock.lock();
3370  auto existingIt = cachedDirectConversions.find(t);
3371  if (existingIt != cachedDirectConversions.end()) {
3372  if (existingIt->second)
3373  results.push_back(existingIt->second);
3374  return success(existingIt->second != nullptr);
3375  }
3376  auto multiIt = cachedMultiConversions.find(t);
3377  if (multiIt != cachedMultiConversions.end()) {
3378  results.append(multiIt->second.begin(), multiIt->second.end());
3379  return success();
3380  }
3381  }
3382  // Walk the added converters in reverse order to apply the most recently
3383  // registered first.
3384  size_t currentCount = results.size();
3385 
3386  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3387  std::defer_lock);
3388 
3389  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
3390  if (std::optional<LogicalResult> result = converter(t, results)) {
3392  cacheWriteLock.lock();
3393  if (!succeeded(*result)) {
3394  assert(results.size() == currentCount &&
3395  "failed type conversion should not change results");
3396  cachedDirectConversions.try_emplace(t, nullptr);
3397  return failure();
3398  }
3399  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3400  if (newTypes.size() == 1)
3401  cachedDirectConversions.try_emplace(t, newTypes.front());
3402  else
3403  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3404  return success();
3405  } else {
3406  assert(results.size() == currentCount &&
3407  "failed type conversion should not change results");
3408  }
3409  }
3410  return failure();
3411 }
3412 
3414  SmallVectorImpl<Type> &results) const {
3415  assert(v && "expected non-null value");
3416 
3417  // If this type converter does not have context-aware type conversions, call
3418  // the type-based overload, which has caching.
3419  if (!hasContextAwareTypeConversions)
3420  return convertType(v.getType(), results);
3421 
3422  // Walk the added converters in reverse order to apply the most recently
3423  // registered first.
3424  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
3425  if (std::optional<LogicalResult> result = converter(v, results)) {
3426  if (!succeeded(*result))
3427  return failure();
3428  return success();
3429  }
3430  }
3431  return failure();
3432 }
3433 
3435  // Use the multi-type result version to convert the type.
3436  SmallVector<Type, 1> results;
3437  if (failed(convertType(t, results)))
3438  return nullptr;
3439 
3440  // Check to ensure that only one type was produced.
3441  return results.size() == 1 ? results.front() : nullptr;
3442 }
3443 
3445  // Use the multi-type result version to convert the type.
3446  SmallVector<Type, 1> results;
3447  if (failed(convertType(v, results)))
3448  return nullptr;
3449 
3450  // Check to ensure that only one type was produced.
3451  return results.size() == 1 ? results.front() : nullptr;
3452 }
3453 
3454 LogicalResult
3456  SmallVectorImpl<Type> &results) const {
3457  for (Type type : types)
3458  if (failed(convertType(type, results)))
3459  return failure();
3460  return success();
3461 }
3462 
3463 LogicalResult
3465  SmallVectorImpl<Type> &results) const {
3466  for (Value value : values)
3467  if (failed(convertType(value, results)))
3468  return failure();
3469  return success();
3470 }
3471 
3472 bool TypeConverter::isLegal(Type type) const {
3473  return convertType(type) == type;
3474 }
3475 
3476 bool TypeConverter::isLegal(Value value) const {
3477  return convertType(value) == value.getType();
3478 }
3479 
3481  return isLegal(op->getOperands()) && isLegal(op->getResults());
3482 }
3483 
3484 bool TypeConverter::isLegal(Region *region) const {
3485  return llvm::all_of(
3486  *region, [this](Block &block) { return isLegal(block.getArguments()); });
3487 }
3488 
3489 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3490  if (!isLegal(ty.getInputs()))
3491  return false;
3492  if (!isLegal(ty.getResults()))
3493  return false;
3494  return true;
3495 }
3496 
3497 LogicalResult
3499  SignatureConversion &result) const {
3500  // Try to convert the given input type.
3501  SmallVector<Type, 1> convertedTypes;
3502  if (failed(convertType(type, convertedTypes)))
3503  return failure();
3504 
3505  // If this argument is being dropped, there is nothing left to do.
3506  if (convertedTypes.empty())
3507  return success();
3508 
3509  // Otherwise, add the new inputs.
3510  result.addInputs(inputNo, convertedTypes);
3511  return success();
3512 }
3513 LogicalResult
3515  SignatureConversion &result,
3516  unsigned origInputOffset) const {
3517  for (unsigned i = 0, e = types.size(); i != e; ++i)
3518  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3519  return failure();
3520  return success();
3521 }
3522 LogicalResult
3524  SignatureConversion &result) const {
3525  // Try to convert the given input type.
3526  SmallVector<Type, 1> convertedTypes;
3527  if (failed(convertType(value, convertedTypes)))
3528  return failure();
3529 
3530  // If this argument is being dropped, there is nothing left to do.
3531  if (convertedTypes.empty())
3532  return success();
3533 
3534  // Otherwise, add the new inputs.
3535  result.addInputs(inputNo, convertedTypes);
3536  return success();
3537 }
3538 LogicalResult
3540  SignatureConversion &result,
3541  unsigned origInputOffset) const {
3542  for (unsigned i = 0, e = values.size(); i != e; ++i)
3543  if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3544  return failure();
3545  return success();
3546 }
3547 
3549  Location loc, Type resultType,
3550  ValueRange inputs) const {
3551  for (const SourceMaterializationCallbackFn &fn :
3552  llvm::reverse(sourceMaterializations))
3553  if (Value result = fn(builder, resultType, inputs, loc))
3554  return result;
3555  return nullptr;
3556 }
3557 
3559  Location loc, Type resultType,
3560  ValueRange inputs,
3561  Type originalType) const {
3563  builder, loc, TypeRange(resultType), inputs, originalType);
3564  if (result.empty())
3565  return nullptr;
3566  assert(result.size() == 1 && "expected single result");
3567  return result.front();
3568 }
3569 
3571  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3572  Type originalType) const {
3573  for (const TargetMaterializationCallbackFn &fn :
3574  llvm::reverse(targetMaterializations)) {
3575  SmallVector<Value> result =
3576  fn(builder, resultTypes, inputs, loc, originalType);
3577  if (result.empty())
3578  continue;
3579  assert(TypeRange(ValueRange(result)) == resultTypes &&
3580  "callback produced incorrect number of values or values with "
3581  "incorrect types");
3582  return result;
3583  }
3584  return {};
3585 }
3586 
3587 std::optional<TypeConverter::SignatureConversion>
3589  SignatureConversion conversion(block->getNumArguments());
3590  if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3591  return std::nullopt;
3592  return conversion;
3593 }
3594 
3595 //===----------------------------------------------------------------------===//
3596 // Type attribute conversion
3597 //===----------------------------------------------------------------------===//
3600  return AttributeConversionResult(attr, resultTag);
3601 }
3602 
3605  return AttributeConversionResult(nullptr, naTag);
3606 }
3607 
3610  return AttributeConversionResult(nullptr, abortTag);
3611 }
3612 
3614  return impl.getInt() == resultTag;
3615 }
3616 
3618  return impl.getInt() == naTag;
3619 }
3620 
3622  return impl.getInt() == abortTag;
3623 }
3624 
3626  assert(hasResult() && "Cannot get result from N/A or abort");
3627  return impl.getPointer();
3628 }
3629 
3630 std::optional<Attribute>
3632  for (const TypeAttributeConversionCallbackFn &fn :
3633  llvm::reverse(typeAttributeConversions)) {
3634  AttributeConversionResult res = fn(type, attr);
3635  if (res.hasResult())
3636  return res.getResult();
3637  if (res.isAbort())
3638  return std::nullopt;
3639  }
3640  return std::nullopt;
3641 }
3642 
3643 //===----------------------------------------------------------------------===//
3644 // FunctionOpInterfaceSignatureConversion
3645 //===----------------------------------------------------------------------===//
3646 
3647 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3648  const TypeConverter &typeConverter,
3649  ConversionPatternRewriter &rewriter) {
3650  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3651  if (!type)
3652  return failure();
3653 
3654  // Convert the original function types.
3655  TypeConverter::SignatureConversion result(type.getNumInputs());
3656  SmallVector<Type, 1> newResults;
3657  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3658  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3659  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3660  typeConverter, &result)))
3661  return failure();
3662 
3663  // Update the function signature in-place.
3664  auto newType = FunctionType::get(rewriter.getContext(),
3665  result.getConvertedTypes(), newResults);
3666 
3667  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3668 
3669  return success();
3670 }
3671 
3672 /// Create a default conversion pattern that rewrites the type signature of a
3673 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3674 /// represent their type.
3675 namespace {
3676 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3677  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3678  MLIRContext *ctx,
3679  const TypeConverter &converter)
3680  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3681 
3682  LogicalResult
3683  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3684  ConversionPatternRewriter &rewriter) const override {
3685  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3686  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3687  }
3688 };
3689 
3690 struct AnyFunctionOpInterfaceSignatureConversion
3691  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3693 
3694  LogicalResult
3695  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3696  ConversionPatternRewriter &rewriter) const override {
3697  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3698  }
3699 };
3700 } // namespace
3701 
3702 FailureOr<Operation *>
3704  const TypeConverter &converter,
3705  ConversionPatternRewriter &rewriter) {
3706  assert(op && "Invalid op");
3707  Location loc = op->getLoc();
3708  if (converter.isLegal(op))
3709  return rewriter.notifyMatchFailure(loc, "op already legal");
3710 
3711  OperationState newOp(loc, op->getName());
3712  newOp.addOperands(operands);
3713 
3714  SmallVector<Type> newResultTypes;
3715  if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3716  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3717 
3718  newOp.addTypes(newResultTypes);
3719  newOp.addAttributes(op->getAttrs());
3720  return rewriter.create(newOp);
3721 }
3722 
3724  StringRef functionLikeOpName, RewritePatternSet &patterns,
3725  const TypeConverter &converter) {
3726  patterns.add<FunctionOpInterfaceSignatureConversion>(
3727  functionLikeOpName, patterns.getContext(), converter);
3728 }
3729 
3731  RewritePatternSet &patterns, const TypeConverter &converter) {
3732  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3733  converter, patterns.getContext());
3734 }
3735 
3736 //===----------------------------------------------------------------------===//
3737 // ConversionTarget
3738 //===----------------------------------------------------------------------===//
3739 
3741  LegalizationAction action) {
3742  legalOperations[op].action = action;
3743 }
3744 
3746  LegalizationAction action) {
3747  for (StringRef dialect : dialectNames)
3748  legalDialects[dialect] = action;
3749 }
3750 
3752  -> std::optional<LegalizationAction> {
3753  std::optional<LegalizationInfo> info = getOpInfo(op);
3754  return info ? info->action : std::optional<LegalizationAction>();
3755 }
3756 
3758  -> std::optional<LegalOpDetails> {
3759  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3760  if (!info)
3761  return std::nullopt;
3762 
3763  // Returns true if this operation instance is known to be legal.
3764  auto isOpLegal = [&] {
3765  // Handle dynamic legality with the provided legality function.
3766  if (info->action == LegalizationAction::Dynamic) {
3767  std::optional<bool> result = info->legalityFn(op);
3768  if (result)
3769  return *result;
3770  }
3771 
3772  // Otherwise, the operation is only legal if it was marked 'Legal'.
3773  return info->action == LegalizationAction::Legal;
3774  };
3775  if (!isOpLegal())
3776  return std::nullopt;
3777 
3778  // This operation is legal, compute any additional legality information.
3779  LegalOpDetails legalityDetails;
3780  if (info->isRecursivelyLegal) {
3781  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3782  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3783  legalityDetails.isRecursivelyLegal =
3784  legalityFnIt->second(op).value_or(true);
3785  } else {
3786  legalityDetails.isRecursivelyLegal = true;
3787  }
3788  }
3789  return legalityDetails;
3790 }
3791 
3793  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3794  if (!info)
3795  return false;
3796 
3797  if (info->action == LegalizationAction::Dynamic) {
3798  std::optional<bool> result = info->legalityFn(op);
3799  if (!result)
3800  return false;
3801 
3802  return !(*result);
3803  }
3804 
3805  return info->action == LegalizationAction::Illegal;
3806 }
3807 
3811  if (!oldCallback)
3812  return newCallback;
3813 
3814  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3815  Operation *op) -> std::optional<bool> {
3816  if (std::optional<bool> result = newCl(op))
3817  return *result;
3818 
3819  return oldCl(op);
3820  };
3821  return chain;
3822 }
3823 
3824 void ConversionTarget::setLegalityCallback(
3825  OperationName name, const DynamicLegalityCallbackFn &callback) {
3826  assert(callback && "expected valid legality callback");
3827  auto *infoIt = legalOperations.find(name);
3828  assert(infoIt != legalOperations.end() &&
3829  infoIt->second.action == LegalizationAction::Dynamic &&
3830  "expected operation to already be marked as dynamically legal");
3831  infoIt->second.legalityFn =
3832  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3833 }
3834 
3836  OperationName name, const DynamicLegalityCallbackFn &callback) {
3837  auto *infoIt = legalOperations.find(name);
3838  assert(infoIt != legalOperations.end() &&
3839  infoIt->second.action != LegalizationAction::Illegal &&
3840  "expected operation to already be marked as legal");
3841  infoIt->second.isRecursivelyLegal = true;
3842  if (callback)
3843  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3844  std::move(opRecursiveLegalityFns[name]), callback);
3845  else
3846  opRecursiveLegalityFns.erase(name);
3847 }
3848 
3849 void ConversionTarget::setLegalityCallback(
3850  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3851  assert(callback && "expected valid legality callback");
3852  for (StringRef dialect : dialects)
3853  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3854  std::move(dialectLegalityFns[dialect]), callback);
3855 }
3856 
3857 void ConversionTarget::setLegalityCallback(
3858  const DynamicLegalityCallbackFn &callback) {
3859  assert(callback && "expected valid legality callback");
3860  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3861 }
3862 
3863 auto ConversionTarget::getOpInfo(OperationName op) const
3864  -> std::optional<LegalizationInfo> {
3865  // Check for info for this specific operation.
3866  const auto *it = legalOperations.find(op);
3867  if (it != legalOperations.end())
3868  return it->second;
3869  // Check for info for the parent dialect.
3870  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3871  if (dialectIt != legalDialects.end()) {
3872  DynamicLegalityCallbackFn callback;
3873  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3874  if (dialectFn != dialectLegalityFns.end())
3875  callback = dialectFn->second;
3876  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3877  callback};
3878  }
3879  // Otherwise, check if we mark unknown operations as dynamic.
3880  if (unknownLegalityFn)
3881  return LegalizationInfo{LegalizationAction::Dynamic,
3882  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3883  return std::nullopt;
3884 }
3885 
3886 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3887 //===----------------------------------------------------------------------===//
3888 // PDL Configuration
3889 //===----------------------------------------------------------------------===//
3890 
3892  auto &rewriterImpl =
3893  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3894  rewriterImpl.currentTypeConverter = getTypeConverter();
3895 }
3896 
3898  auto &rewriterImpl =
3899  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3900  rewriterImpl.currentTypeConverter = nullptr;
3901 }
3902 
3903 /// Remap the given value using the rewriter and the type converter in the
3904 /// provided config.
3905 static FailureOr<SmallVector<Value>>
3907  SmallVector<Value> mappedValues;
3908  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3909  return failure();
3910  return std::move(mappedValues);
3911 }
3912 
3914  patterns.getPDLPatterns().registerRewriteFunction(
3915  "convertValue",
3916  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3917  auto results = pdllConvertValues(
3918  static_cast<ConversionPatternRewriter &>(rewriter), value);
3919  if (failed(results))
3920  return failure();
3921  return results->front();
3922  });
3923  patterns.getPDLPatterns().registerRewriteFunction(
3924  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3925  return pdllConvertValues(
3926  static_cast<ConversionPatternRewriter &>(rewriter), values);
3927  });
3928  patterns.getPDLPatterns().registerRewriteFunction(
3929  "convertType",
3930  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3931  auto &rewriterImpl =
3932  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3933  if (const TypeConverter *converter =
3934  rewriterImpl.currentTypeConverter) {
3935  if (Type newType = converter->convertType(type))
3936  return newType;
3937  return failure();
3938  }
3939  return type;
3940  });
3941  patterns.getPDLPatterns().registerRewriteFunction(
3942  "convertTypes",
3943  [](PatternRewriter &rewriter,
3944  TypeRange types) -> FailureOr<SmallVector<Type>> {
3945  auto &rewriterImpl =
3946  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3947  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3948  if (!converter)
3949  return SmallVector<Type>(types);
3950 
3951  SmallVector<Type> remappedTypes;
3952  if (failed(converter->convertTypes(types, remappedTypes)))
3953  return failure();
3954  return std::move(remappedTypes);
3955  });
3956 }
3957 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3958 
3959 //===----------------------------------------------------------------------===//
3960 // Op Conversion Entry Points
3961 //===----------------------------------------------------------------------===//
3962 
3963 /// This is the type of Action that is dispatched when a conversion is applied.
3965  : public tracing::ActionImpl<ApplyConversionAction> {
3966 public:
3969  static constexpr StringLiteral tag = "apply-conversion";
3970  static constexpr StringLiteral desc =
3971  "Encapsulate the application of a dialect conversion";
3972 
3973  void print(raw_ostream &os) const override { os << tag; }
3974 };
3975 
3976 static LogicalResult applyConversion(ArrayRef<Operation *> ops,
3977  const ConversionTarget &target,
3980  OpConversionMode mode) {
3981  if (ops.empty())
3982  return success();
3983  MLIRContext *ctx = ops.front()->getContext();
3984  LogicalResult status = success();
3985  SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
3987  [&] {
3988  OperationConverter opConverter(ops.front()->getContext(), target,
3989  patterns, config, mode);
3990  status = opConverter.convertOperations(ops);
3991  },
3992  irUnits);
3993  return status;
3994 }
3995 
3996 //===----------------------------------------------------------------------===//
3997 // Partial Conversion
3998 //===----------------------------------------------------------------------===//
3999 
4001  ArrayRef<Operation *> ops, const ConversionTarget &target,
4003  return applyConversion(ops, target, patterns, config,
4004  OpConversionMode::Partial);
4005 }
4006 LogicalResult
4010  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4011 }
4012 
4013 //===----------------------------------------------------------------------===//
4014 // Full Conversion
4015 //===----------------------------------------------------------------------===//
4016 
4018  const ConversionTarget &target,
4021  return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4022 }
4024  const ConversionTarget &target,
4027  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4028 }
4029 
4030 //===----------------------------------------------------------------------===//
4031 // Analysis Conversion
4032 //===----------------------------------------------------------------------===//
4033 
4034 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
4035 /// op is a top-level module op (which is expected to be isolated from above),
4036 /// return that op.
4038  // Check if there is a top-level operation within `ops`. If so, return that
4039  // op.
4040  for (Operation *op : ops) {
4041  if (!op->getParentOp()) {
4042 #ifndef NDEBUG
4043  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4044  "expected top-level op to be isolated from above");
4045  for (Operation *other : ops)
4046  assert(op->isAncestor(other) &&
4047  "expected ops to have a common ancestor");
4048 #endif // NDEBUG
4049  return op;
4050  }
4051  }
4052 
4053  // No top-level op. Find a common ancestor.
4054  Operation *commonAncestor =
4055  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4056  for (Operation *op : ops.drop_front()) {
4057  while (!commonAncestor->isProperAncestor(op)) {
4058  commonAncestor =
4060  assert(commonAncestor &&
4061  "expected to find a common isolated from above ancestor");
4062  }
4063  }
4064 
4065  return commonAncestor;
4066 }
4067 
4071 #ifndef NDEBUG
4072  if (config.legalizableOps)
4073  assert(config.legalizableOps->empty() && "expected empty set");
4074 #endif // NDEBUG
4075 
4076  // Clone closest common ancestor that is isolated from above.
4077  Operation *commonAncestor = findCommonAncestor(ops);
4078  IRMapping mapping;
4079  Operation *clonedAncestor = commonAncestor->clone(mapping);
4080  // Compute inverse IR mapping.
4081  DenseMap<Operation *, Operation *> inverseOperationMap;
4082  for (auto &it : mapping.getOperationMap())
4083  inverseOperationMap[it.second] = it.first;
4084 
4085  // Convert the cloned operations. The original IR will remain unchanged.
4086  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4087  ops, [&](Operation *op) { return mapping.lookup(op); });
4088  LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4089  OpConversionMode::Analysis);
4090 
4091  // Remap `legalizableOps`, so that they point to the original ops and not the
4092  // cloned ops.
4093  if (config.legalizableOps) {
4094  DenseSet<Operation *> originalLegalizableOps;
4095  for (Operation *op : *config.legalizableOps)
4096  originalLegalizableOps.insert(inverseOperationMap[op]);
4097  *config.legalizableOps = std::move(originalLegalizableOps);
4098  }
4099 
4100  // Erase the cloned IR.
4101  clonedAncestor->erase();
4102  return status;
4103 }
4104 
4105 LogicalResult
4109  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4110 }
static void setInsertionPointAfter(OpBuilder &b, Value value)
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value >> &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg, Value repl)
#define DEBUG_TYPE
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps, const SetVector< Block * > &insertedBlocks)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:94
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:308
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
UnitAttr getUnitAttr()
Definition: Builders.cpp:93
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:24
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
const ConversionConfig & getConfig() const
Return the configuration of the current dialect conversion.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={}) override
PatternRewriter hook for inlining the ops of a block into another block.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to)
Replace all the uses of the block argument from with to.
Base class for the conversion patterns.
FailureOr< SmallVector< Value > > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true is operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:764
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:32
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:274
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:325
Block::iterator getPoint() const
Definition: Builders.h:338
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:335
Block * getBlock() const
Definition: Builders.h:337
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:468
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:608
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:416
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getTypes() const
Definition: ValueRange.cpp:28
A unique fingerprint for a specific operation, and all of it's internal operations (if includeNested ...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:304
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:718
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:606
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:236
result_range getResults()
Definition: Operation.h:415
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:896
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:365
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:140
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
Definition: PatternMatch.h:622
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:612
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacements values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:46
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
static WalkResult skip()
Definition: WalkResult.h:48
static WalkResult advance()
Definition: WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
CRTP Implementation of an action.
Definition: Action.h:76
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
@ AfterPatterns
Only attempt to fold not legal operations after applying patterns.
@ BeforePatterns
Only attempt to fold not legal operations before applying patterns.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
bool attachDebugMaterializationKind
If set to "true", the materialization kind ("source" or "target") will be attached to "builtin....
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:306
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:296
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:369
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:393
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
Definition: PatternMatch.h:377
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:366
This struct represents a range of new types or a range of values that remaps an existing signature in...
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to, const TypeConverter *converter)
Replace the given block argument with the given values.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
DenseSet< BlockArgument > replacedArgs
A set of replaced block arguments.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config)
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Block * > patternInsertedBlocks
A set of blocks that were inserted (newly-created blocks or moved blocks) by the current pattern.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void replaceOp(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the results of the given operation with the given values and erase the operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.