MLIR  21.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Iterators.h"
19 #include "llvm/ADT/ScopeExit.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "llvm/Support/SaveAndRestore.h"
25 #include "llvm/Support/ScopedPrinter.h"
26 #include <optional>
27 
28 using namespace mlir;
29 using namespace mlir::detail;
30 
31 #define DEBUG_TYPE "dialect-conversion"
32 
33 /// A utility function to log a successful result for the given reason.
34 template <typename... Args>
35 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
36  LLVM_DEBUG({
37  os.unindent();
38  os.startLine() << "} -> SUCCESS";
39  if (!fmt.empty())
40  os.getOStream() << " : "
41  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
42  os.getOStream() << "\n";
43  });
44 }
45 
46 /// A utility function to log a failure result for the given reason.
47 template <typename... Args>
48 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
49  LLVM_DEBUG({
50  os.unindent();
51  os.startLine() << "} -> FAILURE : "
52  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
53  << "\n";
54  });
55 }
56 
57 /// Helper function that computes an insertion point where the given value is
58 /// defined and can be used without a dominance violation.
60  Block *insertBlock = value.getParentBlock();
61  Block::iterator insertPt = insertBlock->begin();
62  if (OpResult inputRes = dyn_cast<OpResult>(value))
63  insertPt = ++inputRes.getOwner()->getIterator();
64  return OpBuilder::InsertPoint(insertBlock, insertPt);
65 }
66 
67 /// Helper function that computes an insertion point where the given values are
68 /// defined and can be used without a dominance violation.
70  assert(!vals.empty() && "expected at least one value");
71  DominanceInfo domInfo;
72  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
73  for (Value v : vals.drop_front()) {
74  // Choose the "later" insertion point.
76  if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
77  nextPt.getPoint())) {
78  // pt is before nextPt => choose nextPt.
79  pt = nextPt;
80  } else {
81 #ifndef NDEBUG
82  // nextPt should be before pt => choose pt.
83  // If pt, nextPt are no dominance relationship, then there is no valid
84  // insertion point at which all given values are defined.
85  bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
86  pt.getBlock(), pt.getPoint());
87  assert(dom && "unable to find valid insertion point");
88 #endif // NDEBUG
89  }
90  }
91  return pt;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // ConversionValueMapping
96 //===----------------------------------------------------------------------===//
97 
98 /// A vector of SSA values, optimized for the most common case of a single
99 /// value.
100 using ValueVector = SmallVector<Value, 1>;
101 
102 namespace {
103 
104 /// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
105 struct ValueVectorMapInfo {
106  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
107  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
108  static ::llvm::hash_code getHashValue(const ValueVector &val) {
109  return ::llvm::hash_combine_range(val);
110  }
111  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
112  return LHS == RHS;
113  }
114 };
115 
116 /// This class wraps a IRMapping to provide recursive lookup
117 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
118 struct ConversionValueMapping {
119  /// Return "true" if an SSA value is mapped to the given value. May return
120  /// false positives.
121  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
122 
123  /// Lookup the most recently mapped values with the desired types in the
124  /// mapping.
125  ///
126  /// Special cases:
127  /// - If the desired type range is empty, simply return the most recently
128  /// mapped values.
129  /// - If there is no mapping to the desired types, also return the most
130  /// recently mapped values.
131  /// - If there is no mapping for the given values at all, return the given
132  /// value.
133  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
134 
135  /// Lookup the given value within the map, or return an empty vector if the
136  /// value is not mapped. If it is mapped, this follows the same behavior
137  /// as `lookupOrDefault`.
138  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
139 
140  template <typename T>
141  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
142 
143  /// Map a value vector to the one provided.
144  template <typename OldVal, typename NewVal>
145  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
146  map(OldVal &&oldVal, NewVal &&newVal) {
147  LLVM_DEBUG({
148  ValueVector next(newVal);
149  while (true) {
150  assert(next != oldVal && "inserting cyclic mapping");
151  auto it = mapping.find(next);
152  if (it == mapping.end())
153  break;
154  next = it->second;
155  }
156  });
157  mappedTo.insert_range(newVal);
158 
159  mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
160  }
161 
162  /// Map a value vector or single value to the one provided.
163  template <typename OldVal, typename NewVal>
164  std::enable_if_t<!IsValueVector<OldVal>::value ||
165  !IsValueVector<NewVal>::value>
166  map(OldVal &&oldVal, NewVal &&newVal) {
167  if constexpr (IsValueVector<OldVal>{}) {
168  map(std::forward<OldVal>(oldVal), ValueVector{newVal});
169  } else if constexpr (IsValueVector<NewVal>{}) {
170  map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
171  } else {
172  map(ValueVector{oldVal}, ValueVector{newVal});
173  }
174  }
175 
176  void map(Value oldVal, SmallVector<Value> &&newVal) {
177  map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
178  }
179 
180  /// Drop the last mapping for the given values.
181  void erase(const ValueVector &value) { mapping.erase(value); }
182 
183 private:
184  /// Current value mappings.
186 
187  /// All SSA values that are mapped to. May contain false positives.
188  DenseSet<Value> mappedTo;
189 };
190 } // namespace
191 
193 ConversionValueMapping::lookupOrDefault(Value from,
194  TypeRange desiredTypes) const {
195  // Try to find the deepest values that have the desired types. If there is no
196  // such mapping, simply return the deepest values.
197  ValueVector desiredValue;
198  ValueVector current{from};
199  do {
200  // Store the current value if the types match.
201  if (TypeRange(ValueRange(current)) == desiredTypes)
202  desiredValue = current;
203 
204  // If possible, Replace each value with (one or multiple) mapped values.
205  ValueVector next;
206  for (Value v : current) {
207  auto it = mapping.find({v});
208  if (it != mapping.end()) {
209  llvm::append_range(next, it->second);
210  } else {
211  next.push_back(v);
212  }
213  }
214  if (next != current) {
215  // If at least one value was replaced, continue the lookup from there.
216  current = std::move(next);
217  continue;
218  }
219 
220  // Otherwise: Check if there is a mapping for the entire vector. Such
221  // mappings are materializations. (N:M mapping are not supported for value
222  // replacements.)
223  //
224  // Note: From a correctness point of view, materializations do not have to
225  // be stored (and looked up) in the mapping. But for performance reasons,
226  // we choose to reuse existing IR (when possible) instead of creating it
227  // multiple times.
228  auto it = mapping.find(current);
229  if (it == mapping.end()) {
230  // No mapping found: The lookup stops here.
231  break;
232  }
233  current = it->second;
234  } while (true);
235 
236  // If the desired values were found use them, otherwise default to the leaf
237  // values.
238  // Note: If `desiredTypes` is empty, this function always returns `current`.
239  return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
240 }
241 
242 ValueVector ConversionValueMapping::lookupOrNull(Value from,
243  TypeRange desiredTypes) const {
244  ValueVector result = lookupOrDefault(from, desiredTypes);
245  if (result == ValueVector{from} ||
246  (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
247  return {};
248  return result;
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // Rewriter and Translation State
253 //===----------------------------------------------------------------------===//
254 namespace {
/// A snapshot of the conversion rewriter state, expressed as three counters.
/// Used to save a point in time so that a set of rewrites can later be undone.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites{numRewrites}, numIgnoredOperations{numIgnoredOperations},
        numReplacedOps{numReplacedOps} {}

  /// Number of rewrites performed at the time of the snapshot.
  unsigned numRewrites;

  /// Number of ignored operations at the time of the snapshot.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure at the time of the
  /// snapshot.
  unsigned numReplacedOps;
};
272 
273 //===----------------------------------------------------------------------===//
274 // IR rewrites
275 //===----------------------------------------------------------------------===//
276 
277 /// An IR rewrite that can be committed (upon success) or rolled back (upon
278 /// failure).
279 ///
280 /// The dialect conversion keeps track of IR modifications (requested by the
281 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
282 /// are directly applied to the IR as the rewriter API is used, some are applied
283 /// partially, and some are delayed until the `IRRewrite` objects are committed.
284 class IRRewrite {
285 public:
286  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
287  /// Enum values are ordered, so that they can be used in `classof`: first all
288  /// block rewrites, then all operation rewrites.
289  enum class Kind {
290  // Block rewrites
291  CreateBlock,
292  EraseBlock,
293  InlineBlock,
294  MoveBlock,
295  BlockTypeConversion,
296  ReplaceBlockArg,
297  // Operation rewrites
298  MoveOperation,
299  ModifyOperation,
300  ReplaceOperation,
301  CreateOperation,
302  UnresolvedMaterialization
303  };
304 
305  virtual ~IRRewrite() = default;
306 
307  /// Roll back the rewrite. Operations may be erased during rollback.
308  virtual void rollback() = 0;
309 
310  /// Commit the rewrite. At this point, it is certain that the dialect
311  /// conversion will succeed. All IR modifications, except for operation/block
312  /// erasure, must be performed through the given rewriter.
313  ///
314  /// Instead of erasing operations/blocks, they should merely be unlinked
315  /// commit phase and finally be erased during the cleanup phase. This is
316  /// because internal dialect conversion state (such as `mapping`) may still
317  /// be using them.
318  ///
319  /// Any IR modification that was already performed before the commit phase
320  /// (e.g., insertion of an op) must be communicated to the listener that may
321  /// be attached to the given rewriter.
322  virtual void commit(RewriterBase &rewriter) {}
323 
324  /// Cleanup operations/blocks. Cleanup is called after commit.
325  virtual void cleanup(RewriterBase &rewriter) {}
326 
327  Kind getKind() const { return kind; }
328 
329  static bool classof(const IRRewrite *rewrite) { return true; }
330 
331 protected:
332  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
333  : kind(kind), rewriterImpl(rewriterImpl) {}
334 
335  const ConversionConfig &getConfig() const;
336 
337  const Kind kind;
338  ConversionPatternRewriterImpl &rewriterImpl;
339 };
340 
341 /// A block rewrite.
342 class BlockRewrite : public IRRewrite {
343 public:
344  /// Return the block that this rewrite operates on.
345  Block *getBlock() const { return block; }
346 
347  static bool classof(const IRRewrite *rewrite) {
348  return rewrite->getKind() >= Kind::CreateBlock &&
349  rewrite->getKind() <= Kind::ReplaceBlockArg;
350  }
351 
352 protected:
353  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
354  Block *block)
355  : IRRewrite(kind, rewriterImpl), block(block) {}
356 
357  // The block that this rewrite operates on.
358  Block *block;
359 };
360 
361 /// Creation of a block. Block creations are immediately reflected in the IR.
362 /// There is no extra work to commit the rewrite. During rollback, the newly
363 /// created block is erased.
364 class CreateBlockRewrite : public BlockRewrite {
365 public:
366  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
367  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
368 
369  static bool classof(const IRRewrite *rewrite) {
370  return rewrite->getKind() == Kind::CreateBlock;
371  }
372 
373  void commit(RewriterBase &rewriter) override {
374  // The block was already created and inserted. Just inform the listener.
375  if (auto *listener = rewriter.getListener())
376  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
377  }
378 
379  void rollback() override {
380  // Unlink all of the operations within this block, they will be deleted
381  // separately.
382  auto &blockOps = block->getOperations();
383  while (!blockOps.empty())
384  blockOps.remove(blockOps.begin());
385  block->dropAllUses();
386  if (block->getParent())
387  block->erase();
388  else
389  delete block;
390  }
391 };
392 
393 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
394 /// blocks are immediately unlinked, but only erased during cleanup. This makes
395 /// it easier to rollback a block erasure: the block is simply inserted into its
396 /// original location.
397 class EraseBlockRewrite : public BlockRewrite {
398 public:
399  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
400  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
401  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
402 
403  static bool classof(const IRRewrite *rewrite) {
404  return rewrite->getKind() == Kind::EraseBlock;
405  }
406 
407  ~EraseBlockRewrite() override {
408  assert(!block &&
409  "rewrite was neither rolled back nor committed/cleaned up");
410  }
411 
412  void rollback() override {
413  // The block (owned by this rewrite) was not actually erased yet. It was
414  // just unlinked. Put it back into its original position.
415  assert(block && "expected block");
416  auto &blockList = region->getBlocks();
417  Region::iterator before = insertBeforeBlock
418  ? Region::iterator(insertBeforeBlock)
419  : blockList.end();
420  blockList.insert(before, block);
421  block = nullptr;
422  }
423 
424  void commit(RewriterBase &rewriter) override {
425  // Erase the block.
426  assert(block && "expected block");
427  assert(block->empty() && "expected empty block");
428 
429  // Notify the listener that the block is about to be erased.
430  if (auto *listener =
431  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
432  listener->notifyBlockErased(block);
433  }
434 
435  void cleanup(RewriterBase &rewriter) override {
436  // Erase the block.
437  block->dropAllDefinedValueUses();
438  delete block;
439  block = nullptr;
440  }
441 
442 private:
443  // The region in which this block was previously contained.
444  Region *region;
445 
446  // The original successor of this block before it was unlinked. "nullptr" if
447  // this block was the only block in the region.
448  Block *insertBeforeBlock;
449 };
450 
451 /// Inlining of a block. This rewrite is immediately reflected in the IR.
452 /// Note: This rewrite represents only the inlining of the operations. The
453 /// erasure of the inlined block is a separate rewrite.
454 class InlineBlockRewrite : public BlockRewrite {
455 public:
456  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
457  Block *sourceBlock, Block::iterator before)
458  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
459  sourceBlock(sourceBlock),
460  firstInlinedInst(sourceBlock->empty() ? nullptr
461  : &sourceBlock->front()),
462  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
463  // If a listener is attached to the dialect conversion, ops must be moved
464  // one-by-one. When they are moved in bulk, notifications cannot be sent
465  // because the ops that used to be in the source block at the time of the
466  // inlining (before the "commit" phase) are unknown at the time when
467  // notifications are sent (which is during the "commit" phase).
468  assert(!getConfig().listener &&
469  "InlineBlockRewrite not supported if listener is attached");
470  }
471 
472  static bool classof(const IRRewrite *rewrite) {
473  return rewrite->getKind() == Kind::InlineBlock;
474  }
475 
476  void rollback() override {
477  // Put the operations from the destination block (owned by the rewrite)
478  // back into the source block.
479  if (firstInlinedInst) {
480  assert(lastInlinedInst && "expected operation");
481  sourceBlock->getOperations().splice(sourceBlock->begin(),
482  block->getOperations(),
483  Block::iterator(firstInlinedInst),
484  ++Block::iterator(lastInlinedInst));
485  }
486  }
487 
488 private:
489  // The block that originally contained the operations.
490  Block *sourceBlock;
491 
492  // The first inlined operation.
493  Operation *firstInlinedInst;
494 
495  // The last inlined operation.
496  Operation *lastInlinedInst;
497 };
498 
499 /// Moving of a block. This rewrite is immediately reflected in the IR.
500 class MoveBlockRewrite : public BlockRewrite {
501 public:
502  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
503  Region *region, Block *insertBeforeBlock)
504  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
505  insertBeforeBlock(insertBeforeBlock) {}
506 
507  static bool classof(const IRRewrite *rewrite) {
508  return rewrite->getKind() == Kind::MoveBlock;
509  }
510 
511  void commit(RewriterBase &rewriter) override {
512  // The block was already moved. Just inform the listener.
513  if (auto *listener = rewriter.getListener()) {
514  // Note: `previousIt` cannot be passed because this is a delayed
515  // notification and iterators into past IR state cannot be represented.
516  listener->notifyBlockInserted(block, /*previous=*/region,
517  /*previousIt=*/{});
518  }
519  }
520 
521  void rollback() override {
522  // Move the block back to its original position.
523  Region::iterator before =
524  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
525  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
526  }
527 
528 private:
529  // The region in which this block was previously contained.
530  Region *region;
531 
532  // The original successor of this block before it was moved. "nullptr" if
533  // this block was the only block in the region.
534  Block *insertBeforeBlock;
535 };
536 
537 /// Block type conversion. This rewrite is partially reflected in the IR.
538 class BlockTypeConversionRewrite : public BlockRewrite {
539 public:
540  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
541  Block *origBlock, Block *newBlock)
542  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
543  newBlock(newBlock) {}
544 
545  static bool classof(const IRRewrite *rewrite) {
546  return rewrite->getKind() == Kind::BlockTypeConversion;
547  }
548 
549  Block *getOrigBlock() const { return block; }
550 
551  Block *getNewBlock() const { return newBlock; }
552 
553  void commit(RewriterBase &rewriter) override;
554 
555  void rollback() override;
556 
557 private:
558  /// The new block that was created as part of this signature conversion.
559  Block *newBlock;
560 };
561 
562 /// Replacing a block argument. This rewrite is not immediately reflected in the
563 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
564 /// until the rewrite is committed.
565 class ReplaceBlockArgRewrite : public BlockRewrite {
566 public:
567  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
568  Block *block, BlockArgument arg,
569  const TypeConverter *converter)
570  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
571  converter(converter) {}
572 
573  static bool classof(const IRRewrite *rewrite) {
574  return rewrite->getKind() == Kind::ReplaceBlockArg;
575  }
576 
577  void commit(RewriterBase &rewriter) override;
578 
579  void rollback() override;
580 
581 private:
582  BlockArgument arg;
583 
584  /// The current type converter when the block argument was replaced.
585  const TypeConverter *converter;
586 };
587 
588 /// An operation rewrite.
589 class OperationRewrite : public IRRewrite {
590 public:
591  /// Return the operation that this rewrite operates on.
592  Operation *getOperation() const { return op; }
593 
594  static bool classof(const IRRewrite *rewrite) {
595  return rewrite->getKind() >= Kind::MoveOperation &&
596  rewrite->getKind() <= Kind::UnresolvedMaterialization;
597  }
598 
599 protected:
600  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
601  Operation *op)
602  : IRRewrite(kind, rewriterImpl), op(op) {}
603 
604  // The operation that this rewrite operates on.
605  Operation *op;
606 };
607 
608 /// Moving of an operation. This rewrite is immediately reflected in the IR.
609 class MoveOperationRewrite : public OperationRewrite {
610 public:
611  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
612  Operation *op, Block *block, Operation *insertBeforeOp)
613  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
614  insertBeforeOp(insertBeforeOp) {}
615 
616  static bool classof(const IRRewrite *rewrite) {
617  return rewrite->getKind() == Kind::MoveOperation;
618  }
619 
620  void commit(RewriterBase &rewriter) override {
621  // The operation was already moved. Just inform the listener.
622  if (auto *listener = rewriter.getListener()) {
623  // Note: `previousIt` cannot be passed because this is a delayed
624  // notification and iterators into past IR state cannot be represented.
625  listener->notifyOperationInserted(
626  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
627  /*insertPt=*/{}));
628  }
629  }
630 
631  void rollback() override {
632  // Move the operation back to its original position.
633  Block::iterator before =
634  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
635  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
636  }
637 
638 private:
639  // The block in which this operation was previously contained.
640  Block *block;
641 
642  // The original successor of this operation before it was moved. "nullptr"
643  // if this operation was the only operation in the region.
644  Operation *insertBeforeOp;
645 };
646 
647 /// In-place modification of an op. This rewrite is immediately reflected in
648 /// the IR. The previous state of the operation is stored in this object.
649 class ModifyOperationRewrite : public OperationRewrite {
650 public:
651  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
652  Operation *op)
653  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
654  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
655  operands(op->operand_begin(), op->operand_end()),
656  successors(op->successor_begin(), op->successor_end()) {
657  if (OpaqueProperties prop = op->getPropertiesStorage()) {
658  // Make a copy of the properties.
659  propertiesStorage = operator new(op->getPropertiesStorageSize());
660  OpaqueProperties propCopy(propertiesStorage);
661  name.initOpProperties(propCopy, /*init=*/prop);
662  }
663  }
664 
665  static bool classof(const IRRewrite *rewrite) {
666  return rewrite->getKind() == Kind::ModifyOperation;
667  }
668 
669  ~ModifyOperationRewrite() override {
670  assert(!propertiesStorage &&
671  "rewrite was neither committed nor rolled back");
672  }
673 
674  void commit(RewriterBase &rewriter) override {
675  // Notify the listener that the operation was modified in-place.
676  if (auto *listener =
677  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
678  listener->notifyOperationModified(op);
679 
680  if (propertiesStorage) {
681  OpaqueProperties propCopy(propertiesStorage);
682  // Note: The operation may have been erased in the mean time, so
683  // OperationName must be stored in this object.
684  name.destroyOpProperties(propCopy);
685  operator delete(propertiesStorage);
686  propertiesStorage = nullptr;
687  }
688  }
689 
690  void rollback() override {
691  op->setLoc(loc);
692  op->setAttrs(attrs);
693  op->setOperands(operands);
694  for (const auto &it : llvm::enumerate(successors))
695  op->setSuccessor(it.value(), it.index());
696  if (propertiesStorage) {
697  OpaqueProperties propCopy(propertiesStorage);
698  op->copyProperties(propCopy);
699  name.destroyOpProperties(propCopy);
700  operator delete(propertiesStorage);
701  propertiesStorage = nullptr;
702  }
703  }
704 
705 private:
706  OperationName name;
707  LocationAttr loc;
708  DictionaryAttr attrs;
709  SmallVector<Value, 8> operands;
710  SmallVector<Block *, 2> successors;
711  void *propertiesStorage = nullptr;
712 };
713 
714 /// Replacing an operation. Erasing an operation is treated as a special case
715 /// with "null" replacements. This rewrite is not immediately reflected in the
716 /// IR. An internal IR mapping is updated, but values are not replaced and the
717 /// original op is not erased until the rewrite is committed.
718 class ReplaceOperationRewrite : public OperationRewrite {
719 public:
720  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
721  Operation *op, const TypeConverter *converter)
722  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
723  converter(converter) {}
724 
725  static bool classof(const IRRewrite *rewrite) {
726  return rewrite->getKind() == Kind::ReplaceOperation;
727  }
728 
729  void commit(RewriterBase &rewriter) override;
730 
731  void rollback() override;
732 
733  void cleanup(RewriterBase &rewriter) override;
734 
735 private:
736  /// An optional type converter that can be used to materialize conversions
737  /// between the new and old values if necessary.
738  const TypeConverter *converter;
739 };
740 
741 class CreateOperationRewrite : public OperationRewrite {
742 public:
743  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
744  Operation *op)
745  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
746 
747  static bool classof(const IRRewrite *rewrite) {
748  return rewrite->getKind() == Kind::CreateOperation;
749  }
750 
751  void commit(RewriterBase &rewriter) override {
752  // The operation was already created and inserted. Just inform the listener.
753  if (auto *listener = rewriter.getListener())
754  listener->notifyOperationInserted(op, /*previous=*/{});
755  }
756 
757  void rollback() override;
758 };
759 
/// The direction of a materialization.
enum MaterializationKind {
  /// Materializes a conversion from an illegal type to a legal one.
  Target = 0,

  /// Materializes a conversion from a legal type back to an illegal one.
  Source = 1
};
770 
771 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
772 /// op. Unresolved materializations are erased at the end of the dialect
773 /// conversion.
774 class UnresolvedMaterializationRewrite : public OperationRewrite {
775 public:
776  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
777  UnrealizedConversionCastOp op,
778  const TypeConverter *converter,
779  MaterializationKind kind, Type originalType,
780  ValueVector mappedValues);
781 
782  static bool classof(const IRRewrite *rewrite) {
783  return rewrite->getKind() == Kind::UnresolvedMaterialization;
784  }
785 
786  void rollback() override;
787 
788  UnrealizedConversionCastOp getOperation() const {
789  return cast<UnrealizedConversionCastOp>(op);
790  }
791 
792  /// Return the type converter of this materialization (which may be null).
793  const TypeConverter *getConverter() const {
794  return converterAndKind.getPointer();
795  }
796 
797  /// Return the kind of this materialization.
798  MaterializationKind getMaterializationKind() const {
799  return converterAndKind.getInt();
800  }
801 
802  /// Return the original type of the SSA value.
803  Type getOriginalType() const { return originalType; }
804 
805 private:
806  /// The corresponding type converter to use when resolving this
807  /// materialization, and the kind of this materialization.
808  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
809  converterAndKind;
810 
811  /// The original type of the SSA value. Only used for target
812  /// materializations.
813  Type originalType;
814 
815  /// The values in the conversion value mapping that are being replaced by the
816  /// results of this unresolved materialization.
817  ValueVector mappedValues;
818 };
819 } // namespace
820 
821 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
822 /// Return "true" if there is an operation rewrite that matches the specified
823 /// rewrite type and operation among the given rewrites.
824 template <typename RewriteTy, typename R>
825 static bool hasRewrite(R &&rewrites, Operation *op) {
826  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
827  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
828  return rewriteTy && rewriteTy->getOperation() == op;
829  });
830 }
831 
832 /// Return "true" if there is a block rewrite that matches the specified
833 /// rewrite type and block among the given rewrites.
834 template <typename RewriteTy, typename R>
835 static bool hasRewrite(R &&rewrites, Block *block) {
836  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
837  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
838  return rewriteTy && rewriteTy->getBlock() == block;
839  });
840 }
841 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
842 
843 //===----------------------------------------------------------------------===//
844 // ConversionPatternRewriterImpl
845 //===----------------------------------------------------------------------===//
846 namespace mlir {
847 namespace detail {
850  const ConversionConfig &config)
851  : context(ctx), config(config) {}
852 
853  //===--------------------------------------------------------------------===//
854  // State Management
855  //===--------------------------------------------------------------------===//
856 
857  /// Return the current state of the rewriter.
858  RewriterState getCurrentState();
859 
860  /// Apply all requested operation rewrites. This method is invoked when the
861  /// conversion process succeeds.
862  void applyRewrites();
863 
864  /// Reset the state of the rewriter to a previously saved point. Optionally,
865  /// the name of the pattern that triggered the rollback can be specified for
866  /// debugging purposes.
867  void resetState(RewriterState state, StringRef patternName = "");
868 
869  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
870  /// failure.
871  template <typename RewriteTy, typename... Args>
872  void appendRewrite(Args &&...args) {
873  rewrites.push_back(
874  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
875  }
876 
877  /// Undo the rewrites (motions, splits) one by one in reverse order until
878  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
879  /// that triggered the rollback can be specified for debugging purposes.
880  void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
881 
882  /// Remap the given values to those with potentially different types. Returns
883  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
884  /// is the tag used when describing a value within a diagnostic, e.g.
885  /// "operand".
886  LogicalResult remapValues(StringRef valueDiagTag,
887  std::optional<Location> inputLoc,
888  PatternRewriter &rewriter, ValueRange values,
889  SmallVector<ValueVector> &remapped);
890 
891  /// Return "true" if the given operation is ignored, and does not need to be
892  /// converted.
893  bool isOpIgnored(Operation *op) const;
894 
895  /// Return "true" if the given operation was replaced or erased.
896  bool wasOpReplaced(Operation *op) const;
897 
898  //===--------------------------------------------------------------------===//
899  // Type Conversion
900  //===--------------------------------------------------------------------===//
901 
902  /// Convert the types of block arguments within the given region.
903  FailureOr<Block *>
904  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
905  const TypeConverter &converter,
906  TypeConverter::SignatureConversion *entryConversion);
907 
908  /// Apply the given signature conversion on the given block. The new block
909  /// containing the updated signature is returned. If no conversions were
910  /// necessary, e.g. if the block has no arguments, `block` is returned.
911  /// `converter` is used to generate any necessary cast operations that
912  /// translate between the origin argument types and those specified in the
913  /// signature conversion.
914  Block *applySignatureConversion(
915  ConversionPatternRewriter &rewriter, Block *block,
916  const TypeConverter *converter,
917  TypeConverter::SignatureConversion &signatureConversion);
918 
919  //===--------------------------------------------------------------------===//
920  // Materializations
921  //===--------------------------------------------------------------------===//
922 
923  /// Build an unresolved materialization operation given a range of output
924  /// types and a list of input operands. Returns the inputs if their
925  /// types match the output types.
926  ///
927  /// If a cast op was built, it can optionally be returned with the `castOp`
928  /// output argument.
929  ///
930  /// If `valuesToMap` is non-empty, then those values are mapped to the
931  /// results of the unresolved materialization in the conversion value
932  /// mapping.
933  ValueRange buildUnresolvedMaterialization(
934  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
935  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
936  Type originalType, const TypeConverter *converter,
937  UnrealizedConversionCastOp *castOp = nullptr);
938 
939  /// Find a replacement value for the given SSA value in the conversion value
940  /// mapping. The replacement value must have the same type as the given SSA
941  /// value. If there is no replacement value with the correct type, find the
942  /// latest replacement value (regardless of the type) and build a source
943  /// materialization.
944  Value findOrBuildReplacementValue(Value value,
945  const TypeConverter *converter);
946 
947  //===--------------------------------------------------------------------===//
948  // Rewriter Notification Hooks
949  //===--------------------------------------------------------------------===//
950 
951  /// Notifies that an op was inserted.
952  void notifyOperationInserted(Operation *op,
953  OpBuilder::InsertPoint previous) override;
954 
955  /// Notifies that an op is about to be replaced with the given values.
956  void notifyOpReplaced(Operation *op,
957  SmallVector<SmallVector<Value>> &&newValues);
958 
959  /// Notifies that a block is about to be erased.
960  void notifyBlockIsBeingErased(Block *block);
961 
962  /// Notifies that a block was inserted.
963  void notifyBlockInserted(Block *block, Region *previous,
964  Region::iterator previousIt) override;
965 
966  /// Notifies that a block is being inlined into another block.
967  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
968  Block::iterator before);
969 
970  /// Notifies that a pattern match failed for the given reason.
971  void
972  notifyMatchFailure(Location loc,
973  function_ref<void(Diagnostic &)> reasonCallback) override;
974 
975  //===--------------------------------------------------------------------===//
976  // IR Erasure
977  //===--------------------------------------------------------------------===//
978 
979  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
980  /// operation or block is erased multiple times. This rewriter assumes that
981  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
983  public:
985  MLIRContext *context,
986  std::function<void(Operation *)> opErasedCallback = nullptr)
987  : RewriterBase(context, /*listener=*/this),
988  opErasedCallback(opErasedCallback) {}
989 
990  /// Erase the given op (unless it was already erased).
991  void eraseOp(Operation *op) override {
992  if (wasErased(op))
993  return;
994  op->dropAllUses();
996  }
997 
998  /// Erase the given block (unless it was already erased).
999  void eraseBlock(Block *block) override {
1000  if (wasErased(block))
1001  return;
1002  assert(block->empty() && "expected empty block");
1003  block->dropAllDefinedValueUses();
1004  RewriterBase::eraseBlock(block);
1005  }
1006 
1007  bool wasErased(void *ptr) const { return erased.contains(ptr); }
1008 
1009  void notifyOperationErased(Operation *op) override {
1010  erased.insert(op);
1011  if (opErasedCallback)
1012  opErasedCallback(op);
1013  }
1014 
1015  void notifyBlockErased(Block *block) override { erased.insert(block); }
1016 
1017  private:
1018  /// Pointers to all erased operations and blocks.
1019  DenseSet<void *> erased;
1020 
1021  /// A callback that is invoked when an operation is erased.
1022  std::function<void(Operation *)> opErasedCallback;
1023  };
1024 
1025  //===--------------------------------------------------------------------===//
1026  // State
1027  //===--------------------------------------------------------------------===//
1028 
1029  /// MLIR context.
1031 
1032  // Mapping between replaced values that differ in type. This happens when
1033  // replacing a value with one of a different type.
1034  ConversionValueMapping mapping;
1035 
1036  /// Ordered list of block operations (creations, splits, motions).
1038 
1039  /// A set of operations that should no longer be considered for legalization.
1040  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1041  /// tracked separately.
1043 
1044  /// A set of operations that were replaced/erased. Such ops are not erased
1045  /// immediately but only when the dialect conversion succeeds. In the mean
1046  /// time, they should no longer be considered for legalization and any attempt
1047  /// to modify/access them is invalid rewriter API usage.
1049 
1050  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1051  /// to the corresponding rewrite objects.
1054 
1055  /// The current type converter, or nullptr if no type converter is currently
1056  /// active.
1057  const TypeConverter *currentTypeConverter = nullptr;
1058 
1059  /// A mapping of regions to type converters that should be used when
1060  /// converting the arguments of blocks within that region.
1062 
1063  /// Dialect conversion configuration.
1065 
1066 #ifndef NDEBUG
1067  /// A set of operations that have pending updates. This tracking isn't
1068  /// strictly necessary, and is thus only active during debug builds for extra
1069  /// verification.
1071 
1072  /// A logger used to emit diagnostics during the conversion process.
1073  llvm::ScopedPrinter logger{llvm::dbgs()};
1074 #endif
1075 };
1076 } // namespace detail
1077 } // namespace mlir
1078 
1079 const ConversionConfig &IRRewrite::getConfig() const {
1080  return rewriterImpl.config;
1081 }
1082 
1083 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1084  // Inform the listener about all IR modifications that have already taken
1085  // place: References to the original block have been replaced with the new
1086  // block.
1087  if (auto *listener =
1088  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1089  for (Operation *op : getNewBlock()->getUsers())
1090  listener->notifyOperationModified(op);
1091 }
1092 
1093 void BlockTypeConversionRewrite::rollback() {
1094  getNewBlock()->replaceAllUsesWith(getOrigBlock());
1095 }
1096 
1097 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1098  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1099  if (!repl)
1100  return;
1101 
1102  if (isa<BlockArgument>(repl)) {
1103  rewriter.replaceAllUsesWith(arg, repl);
1104  return;
1105  }
1106 
1107  // If the replacement value is an operation, we check to make sure that we
1108  // don't replace uses that are within the parent operation of the
1109  // replacement value.
1110  Operation *replOp = cast<OpResult>(repl).getOwner();
1111  Block *replBlock = replOp->getBlock();
1112  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1113  Operation *user = operand.getOwner();
1114  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1115  });
1116 }
1117 
1118 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
1119 
1120 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1121  auto *listener =
1122  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1123 
1124  // Compute replacement values.
1125  SmallVector<Value> replacements =
1126  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1127  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1128  });
1129 
1130  // Notify the listener that the operation is about to be replaced.
1131  if (listener)
1132  listener->notifyOperationReplaced(op, replacements);
1133 
1134  // Replace all uses with the new values.
1135  for (auto [result, newValue] :
1136  llvm::zip_equal(op->getResults(), replacements))
1137  if (newValue)
1138  rewriter.replaceAllUsesWith(result, newValue);
1139 
1140  // The original op will be erased, so remove it from the set of unlegalized
1141  // ops.
1142  if (getConfig().unlegalizedOps)
1143  getConfig().unlegalizedOps->erase(op);
1144 
1145  // Notify the listener that the operation (and its nested operations) was
1146  // erased.
1147  if (listener) {
1149  [&](Operation *op) { listener->notifyOperationErased(op); });
1150  }
1151 
1152  // Do not erase the operation yet. It may still be referenced in `mapping`.
1153  // Just unlink it for now and erase it during cleanup.
1154  op->getBlock()->getOperations().remove(op);
1155 }
1156 
1157 void ReplaceOperationRewrite::rollback() {
1158  for (auto result : op->getResults())
1159  rewriterImpl.mapping.erase({result});
1160 }
1161 
1162 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1163  rewriter.eraseOp(op);
1164 }
1165 
1166 void CreateOperationRewrite::rollback() {
1167  for (Region &region : op->getRegions()) {
1168  while (!region.getBlocks().empty())
1169  region.getBlocks().remove(region.getBlocks().begin());
1170  }
1171  op->dropAllUses();
1172  op->erase();
1173 }
1174 
1175 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1176  ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1177  const TypeConverter *converter, MaterializationKind kind, Type originalType,
1178  ValueVector mappedValues)
1179  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1180  converterAndKind(converter, kind), originalType(originalType),
1181  mappedValues(std::move(mappedValues)) {
1182  assert((!originalType || kind == MaterializationKind::Target) &&
1183  "original type is valid only for target materializations");
1184  rewriterImpl.unresolvedMaterializations[op] = this;
1185 }
1186 
1187 void UnresolvedMaterializationRewrite::rollback() {
1188  if (!mappedValues.empty())
1189  rewriterImpl.mapping.erase(mappedValues);
1190  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1191  op->erase();
1192 }
1193 
1195  // Commit all rewrites.
1196  IRRewriter rewriter(context, config.listener);
1197  // Note: New rewrites may be added during the "commit" phase and the
1198  // `rewrites` vector may reallocate.
1199  for (size_t i = 0; i < rewrites.size(); ++i)
1200  rewrites[i]->commit(rewriter);
1201 
1202  // Clean up all rewrites.
1203  SingleEraseRewriter eraseRewriter(
1204  context, /*opErasedCallback=*/[&](Operation *op) {
1205  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1206  unresolvedMaterializations.erase(castOp);
1207  });
1208  for (auto &rewrite : rewrites)
1209  rewrite->cleanup(eraseRewriter);
1210 }
1211 
1212 //===----------------------------------------------------------------------===//
1213 // State Management
1214 //===----------------------------------------------------------------------===//
1215 
1217  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1218 }
1219 
1221  StringRef patternName) {
1222  // Undo any rewrites.
1223  undoRewrites(state.numRewrites, patternName);
1224 
1225  // Pop all of the recorded ignored operations that are no longer valid.
1226  while (ignoredOps.size() != state.numIgnoredOperations)
1227  ignoredOps.pop_back();
1228 
1229  while (replacedOps.size() != state.numReplacedOps)
1230  replacedOps.pop_back();
1231 }
1232 
1233 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1234  StringRef patternName) {
1235  for (auto &rewrite :
1236  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
1238  !isa<UnresolvedMaterializationRewrite>(rewrite)) {
1239  // Unresolved materializations can always be rolled back (erased).
1240  llvm::report_fatal_error("pattern '" + patternName +
1241  "' rollback of IR modifications requested");
1242  }
1243  rewrite->rollback();
1244  }
1245  rewrites.resize(numRewritesToKeep);
1246 }
1247 
1249  StringRef valueDiagTag, std::optional<Location> inputLoc,
1250  PatternRewriter &rewriter, ValueRange values,
1251  SmallVector<ValueVector> &remapped) {
1252  remapped.reserve(llvm::size(values));
1253 
1254  for (const auto &it : llvm::enumerate(values)) {
1255  Value operand = it.value();
1256  Type origType = operand.getType();
1257  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1258 
1259  if (!currentTypeConverter) {
1260  // The current pattern does not have a type converter. I.e., it does not
1261  // distinguish between legal and illegal types. For each operand, simply
1262  // pass through the most recently mapped values.
1263  remapped.push_back(mapping.lookupOrDefault(operand));
1264  continue;
1265  }
1266 
1267  // If there is no legal conversion, fail to match this pattern.
1268  SmallVector<Type, 1> legalTypes;
1269  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1270  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1271  diag << "unable to convert type for " << valueDiagTag << " #"
1272  << it.index() << ", type was " << origType;
1273  });
1274  return failure();
1275  }
1276  // If a type is converted to 0 types, there is nothing to do.
1277  if (legalTypes.empty()) {
1278  remapped.push_back({});
1279  continue;
1280  }
1281 
1282  ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
1283  if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1284  // Mapped values have the correct type or there is an existing
1285  // materialization. Or the operand is not mapped at all and has the
1286  // correct type.
1287  remapped.push_back(std::move(repl));
1288  continue;
1289  }
1290 
1291  // Create a materialization for the most recently mapped values.
1292  repl = mapping.lookupOrDefault(operand);
1294  MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1295  /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1296  /*originalType=*/origType, currentTypeConverter);
1297  remapped.push_back(castValues);
1298  }
1299  return success();
1300 }
1301 
1303  // Check to see if this operation is ignored or was replaced.
1304  return replacedOps.count(op) || ignoredOps.count(op);
1305 }
1306 
1308  // Check to see if this operation was replaced.
1309  return replacedOps.count(op);
1310 }
1311 
1312 //===----------------------------------------------------------------------===//
1313 // Type Conversion
1314 //===----------------------------------------------------------------------===//
1315 
1317  ConversionPatternRewriter &rewriter, Region *region,
1318  const TypeConverter &converter,
1319  TypeConverter::SignatureConversion *entryConversion) {
1320  regionToConverter[region] = &converter;
1321  if (region->empty())
1322  return nullptr;
1323 
1324  // Convert the arguments of each non-entry block within the region.
1325  for (Block &block :
1326  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1327  // Compute the signature for the block with the provided converter.
1328  std::optional<TypeConverter::SignatureConversion> conversion =
1329  converter.convertBlockSignature(&block);
1330  if (!conversion)
1331  return failure();
1332  // Convert the block with the computed signature.
1333  applySignatureConversion(rewriter, &block, &converter, *conversion);
1334  }
1335 
1336  // Convert the entry block. If an entry signature conversion was provided,
1337  // use that one. Otherwise, compute the signature with the type converter.
1338  if (entryConversion)
1339  return applySignatureConversion(rewriter, &region->front(), &converter,
1340  *entryConversion);
1341  std::optional<TypeConverter::SignatureConversion> conversion =
1342  converter.convertBlockSignature(&region->front());
1343  if (!conversion)
1344  return failure();
1345  return applySignatureConversion(rewriter, &region->front(), &converter,
1346  *conversion);
1347 }
1348 
1350  ConversionPatternRewriter &rewriter, Block *block,
1351  const TypeConverter *converter,
1352  TypeConverter::SignatureConversion &signatureConversion) {
1353 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1354  // A block cannot be converted multiple times.
1355  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1356  llvm::report_fatal_error("block was already converted");
1357 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1358 
1359  OpBuilder::InsertionGuard g(rewriter);
1360 
1361  // If no arguments are being changed or added, there is nothing to do.
1362  unsigned origArgCount = block->getNumArguments();
1363  auto convertedTypes = signatureConversion.getConvertedTypes();
1364  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1365  return block;
1366 
1367  // Compute the locations of all block arguments in the new block.
1368  SmallVector<Location> newLocs(convertedTypes.size(),
1369  rewriter.getUnknownLoc());
1370  for (unsigned i = 0; i < origArgCount; ++i) {
1371  auto inputMap = signatureConversion.getInputMapping(i);
1372  if (!inputMap || inputMap->replacedWithValues())
1373  continue;
1374  Location origLoc = block->getArgument(i).getLoc();
1375  for (unsigned j = 0; j < inputMap->size; ++j)
1376  newLocs[inputMap->inputNo + j] = origLoc;
1377  }
1378 
1379  // Insert a new block with the converted block argument types and move all ops
1380  // from the old block to the new block.
1381  Block *newBlock =
1382  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1383  convertedTypes, newLocs);
1384 
1385  // If a listener is attached to the dialect conversion, ops cannot be moved
1386  // to the destination block in bulk ("fast path"). This is because at the time
1387  // the notifications are sent, it is unknown which ops were moved. Instead,
1388  // ops should be moved one-by-one ("slow path"), so that a separate
1389  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1390  // a bit more efficient, so we try to do that when possible.
1391  bool fastPath = !config.listener;
1392  if (fastPath) {
1393  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1394  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1395  } else {
1396  while (!block->empty())
1397  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1398  }
1399 
1400  // Replace all uses of the old block with the new block.
1401  block->replaceAllUsesWith(newBlock);
1402 
1403  for (unsigned i = 0; i != origArgCount; ++i) {
1404  BlockArgument origArg = block->getArgument(i);
1405  Type origArgType = origArg.getType();
1406 
1407  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1408  signatureConversion.getInputMapping(i);
1409  if (!inputMap) {
1410  // This block argument was dropped and no replacement value was provided.
1411  // Materialize a replacement value "out of thin air".
1413  MaterializationKind::Source,
1414  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1415  /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
1416  /*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
1417  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1418  continue;
1419  }
1420 
1421  if (inputMap->replacedWithValues()) {
1422  // This block argument was dropped and replacement values were provided.
1423  assert(inputMap->size == 0 &&
1424  "invalid to provide a replacement value when the argument isn't "
1425  "dropped");
1426  mapping.map(origArg, inputMap->replacementValues);
1427  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1428  continue;
1429  }
1430 
1431  // This is a 1->1+ mapping.
1432  auto replArgs =
1433  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1434  ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
1435  mapping.map(origArg, std::move(replArgVals));
1436  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1437  }
1438 
1439  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1440 
1441  // Erase the old block. (It is just unlinked for now and will be erased during
1442  // cleanup.)
1443  rewriter.eraseBlock(block);
1444 
1445  return newBlock;
1446 }
1447 
1448 //===----------------------------------------------------------------------===//
1449 // Materializations
1450 //===----------------------------------------------------------------------===//
1451 
1452 /// Build an unresolved materialization operation given an output type and set
1453 /// of input operands.
1455  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1456  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1457  Type originalType, const TypeConverter *converter,
1458  UnrealizedConversionCastOp *castOp) {
1459  assert((!originalType || kind == MaterializationKind::Target) &&
1460  "original type is valid only for target materializations");
1461  assert(TypeRange(inputs) != outputTypes &&
1462  "materialization is not necessary");
1463 
1464  // Create an unresolved materialization. We use a new OpBuilder to avoid
1465  // tracking the materialization like we do for other operations.
1466  OpBuilder builder(outputTypes.front().getContext());
1467  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1468  auto convertOp =
1469  builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
1470  if (!valuesToMap.empty())
1471  mapping.map(valuesToMap, convertOp.getResults());
1472  if (castOp)
1473  *castOp = convertOp;
1474  appendRewrite<UnresolvedMaterializationRewrite>(
1475  convertOp, converter, kind, originalType, std::move(valuesToMap));
1476  return convertOp.getResults();
1477 }
1478 
1480  Value value, const TypeConverter *converter) {
1481  // Try to find a replacement value with the same type in the conversion value
1482  // mapping. This includes cached materializations. We try to reuse those
1483  // instead of generating duplicate IR.
1484  ValueVector repl = mapping.lookupOrNull(value, value.getType());
1485  if (!repl.empty())
1486  return repl.front();
1487 
1488  // Check if the value is dead. No replacement value is needed in that case.
1489  // This is an approximate check that may have false negatives but does not
1490  // require computing and traversing an inverse mapping. (We may end up
1491  // building source materializations that are never used and that fold away.)
1492  if (llvm::all_of(value.getUsers(),
1493  [&](Operation *op) { return replacedOps.contains(op); }) &&
1494  !mapping.isMappedTo(value))
1495  return Value();
1496 
1497  // No replacement value was found. Get the latest replacement value
1498  // (regardless of the type) and build a source materialization to the
1499  // original type.
1500  repl = mapping.lookupOrNull(value);
1501  if (repl.empty()) {
1502  // No replacement value is registered in the mapping. This means that the
1503  // value is dropped and no longer needed. (If the value were still needed,
1504  // a source materialization producing a replacement value "out of thin air"
1505  // would have already been created during `replaceOp` or
1506  // `applySignatureConversion`.)
1507  return Value();
1508  }
1509 
1510  // Note: `computeInsertPoint` computes the "earliest" insertion point at
1511  // which all values in `repl` are defined. It is important to emit the
1512  // materialization at that location because the same materialization may be
1513  // reused in a different context. (That's because materializations are cached
1514  // in the conversion value mapping.) The insertion point of the
1515  // materialization must be valid for all future users that may be created
1516  // later in the conversion process.
1517  Value castValue =
1518  buildUnresolvedMaterialization(MaterializationKind::Source,
1519  computeInsertPoint(repl), value.getLoc(),
1520  /*valuesToMap=*/repl, /*inputs=*/repl,
1521  /*outputTypes=*/value.getType(),
1522  /*originalType=*/Type(), converter)
1523  .front();
1524  return castValue;
1525 }
1526 
1527 //===----------------------------------------------------------------------===//
1528 // Rewriter Notification Hooks
1529 //===----------------------------------------------------------------------===//
1530 
1532  Operation *op, OpBuilder::InsertPoint previous) {
1533  LLVM_DEBUG({
1534  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1535  << ")\n";
1536  });
1537  assert(!wasOpReplaced(op->getParentOp()) &&
1538  "attempting to insert into a block within a replaced/erased op");
1539 
1540  if (!previous.isSet()) {
1541  // This is a newly created op.
1542  appendRewrite<CreateOperationRewrite>(op);
1543  return;
1544  }
1545  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1546  ? nullptr
1547  : &*previous.getPoint();
1548  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1549 }
1550 
1552  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1553  assert(newValues.size() == op->getNumResults());
1554  assert(!ignoredOps.contains(op) && "operation was already replaced");
1555 
1556  // Check if replaced op is an unresolved materialization, i.e., an
1557  // unrealized_conversion_cast op that was created by the conversion driver.
1558  bool isUnresolvedMaterialization = false;
1559  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1560  if (unresolvedMaterializations.contains(castOp))
1561  isUnresolvedMaterialization = true;
1562 
1563  // Create mappings for each of the new result values.
1564  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
1565  if (repl.empty()) {
1566  // This result was dropped and no replacement value was provided.
1567  if (isUnresolvedMaterialization) {
1568  // Do not create another materialization if we are erasing a
1569  // materialization.
1570  continue;
1571  }
1572 
1573  // Materialize a replacement value "out of thin air".
1575  MaterializationKind::Source, computeInsertPoint(result),
1576  result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
1577  /*outputTypes=*/result.getType(), /*originalType=*/Type(),
1579  continue;
1580  } else {
1581  // Make sure that the user does not mess with unresolved materializations
1582  // that were inserted by the conversion driver. We keep track of these
1583  // ops in internal data structures. Erasing them must be allowed because
1584  // this can happen when the user is erasing an entire block (including
1585  // its body). But replacing them with another value should be forbidden
1586  // to avoid problems with the `mapping`.
1587  assert(!isUnresolvedMaterialization &&
1588  "attempting to replace an unresolved materialization");
1589  }
1590 
1591  // Remap result to replacement value.
1592  if (repl.empty())
1593  continue;
1594  mapping.map(static_cast<Value>(result), std::move(repl));
1595  }
1596 
1597  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1598  // Mark this operation and all nested ops as replaced.
1599  op->walk([&](Operation *op) { replacedOps.insert(op); });
1600 }
1601 
1603  appendRewrite<EraseBlockRewrite>(block);
1604 }
1605 
1607  Block *block, Region *previous, Region::iterator previousIt) {
1608  assert(!wasOpReplaced(block->getParentOp()) &&
1609  "attempting to insert into a region within a replaced/erased op");
1610  LLVM_DEBUG(
1611  {
1612  Operation *parent = block->getParentOp();
1613  if (parent) {
1614  logger.startLine() << "** Insert Block into : '" << parent->getName()
1615  << "'(" << parent << ")\n";
1616  } else {
1617  logger.startLine()
1618  << "** Insert Block into detached Region (nullptr parent op)'\n";
1619  }
1620  });
1621 
1622  if (!previous) {
1623  // This is a newly created block.
1624  appendRewrite<CreateBlockRewrite>(block);
1625  return;
1626  }
1627  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1628  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1629 }
1630 
1632  Block *block, Block *srcBlock, Block::iterator before) {
1633  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1634 }
1635 
1637  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1638  LLVM_DEBUG({
1640  reasonCallback(diag);
1641  logger.startLine() << "** Failure : " << diag.str() << "\n";
1642  if (config.notifyCallback)
1644  });
1645 }
1646 
1647 //===----------------------------------------------------------------------===//
1648 // ConversionPatternRewriter
1649 //===----------------------------------------------------------------------===//
1650 
1651 ConversionPatternRewriter::ConversionPatternRewriter(
1652  MLIRContext *ctx, const ConversionConfig &config)
1653  : PatternRewriter(ctx),
1654  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1655  setListener(impl.get());
1656 }
1657 
1659 
// Replaces `op` with the results of `newOp` by forwarding to the value-based
// `replaceOp`. Both pointers must be non-null.
// NOTE(review): the signature line of this definition is not visible in this
// listing.
1661  assert(op && newOp && "expected non-null op");
1662  replaceOp(op, newOp->getResults());
1663 }
1664 
// Replaces `op` with `newValues`, which must hold exactly one entry per result
// of `op`.
// NOTE(review): the signature line of this definition is not visible in this
// listing.
1666  assert(op->getNumResults() == newValues.size() &&
1667  "incorrect # of replacement values");
1668  LLVM_DEBUG({
1669  impl->logger.startLine()
1670  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1671  });
      // Wrap every replacement in its own vector; a null value becomes an empty
      // vector.
      // NOTE(review): the line declaring `newVals` is not visible in this
      // listing.
1673  llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
1674  return v ? SmallVector<Value>{v} : SmallVector<Value>();
1675  });
1676  impl->notifyOpReplaced(op, std::move(newVals));
1677 }
1678 
1680  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
      // Replaces `op` with one list of replacement values per result; the lists
      // are moved into the implementation's `notifyOpReplaced`.
      // NOTE(review): the first signature line is not visible in this listing.
1681  assert(op->getNumResults() == newValues.size() &&
1682  "incorrect # of replacement values");
1683  LLVM_DEBUG({
1684  impl->logger.startLine()
1685  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1686  });
1687  impl->notifyOpReplaced(op, std::move(newValues));
1688 }
1689 
// Erases `op`. The erasure is reported to the implementation as a replacement
// in which every result has an empty list of replacement values.
// NOTE(review): the signature line of this definition is not visible in this
// listing.
1691  LLVM_DEBUG({
1692  impl->logger.startLine()
1693  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1694  });
1695  SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
1696  impl->notifyOpReplaced(op, std::move(nullRepls));
1697 }
1698 
// Erases `block`: every operation in it is erased through `eraseOp`, the
// erasure is reported to the implementation, and the block is unlinked from
// its parent region.
// NOTE(review): the signature line of this definition is not visible in this
// listing.
1700  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1701  "attempting to erase a block within a replaced/erased op");
1702 
1703  // Mark all ops for erasure.
1704  for (Operation &op : *block)
1705  eraseOp(&op);
1706 
1707  // Unlink the block from its parent region. The block is kept in the rewrite
1708  // object and will be actually destroyed when rewrites are applied. This
1709  // allows us to keep the operations in the block live and undo the removal by
1710  // re-inserting the block.
1711  impl->notifyBlockIsBeingErased(block);
1712  block->getParent()->getBlocks().remove(block);
1713 }
1714 
1716  Block *block, TypeConverter::SignatureConversion &conversion,
1717  const TypeConverter *converter) {
      // Checks that the parent op of `block` has not been replaced/erased, then
      // delegates to the implementation. Note that the implementation takes
      // `converter` before `conversion`.
      // NOTE(review): the first signature line is not visible in this listing.
1718  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1719  "attempting to apply a signature conversion to a block within a "
1720  "replaced/erased op");
1721  return impl->applySignatureConversion(*this, block, converter, conversion);
1722 }
1723 
1725  Region *region, const TypeConverter &converter,
1726  TypeConverter::SignatureConversion *entryConversion) {
      // Checks that the parent op of `region` has not been replaced/erased,
      // then delegates to the implementation's `convertRegionTypes`.
      // NOTE(review): the first signature line is not visible in this listing.
1727  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1728  "attempting to apply a signature conversion to a block within a "
1729  "replaced/erased op");
1730  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1731 }
1732 
1734  Value to) {
      // Replaces the block argument `from` with `to`: a
      // `ReplaceBlockArgRewrite` is appended for the owning block and the pair
      // is entered into the implementation's value mapping.
      // NOTE(review): the first signature line is not visible in this listing.
1735  LLVM_DEBUG({
1736  impl->logger.startLine() << "** Replace Argument : '" << from << "'";
1737  if (Operation *parentOp = from.getOwner()->getParentOp()) {
1738  impl->logger.getOStream() << " (in region of '" << parentOp->getName()
1739  << "' (" << parentOp << ")\n";
1740  } else {
1741  impl->logger.getOStream() << " (unlinked block)\n";
1742  }
1743  });
1744  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1745  impl->currentTypeConverter);
1746  impl->mapping.map(from, to);
1747 }
1748 
// Returns the value that `key` is remapped to, or a null value when remapping
// fails. Only 1:1 remappings are supported (asserted).
// NOTE(review): the signature line of this definition is not visible in this
// listing.
1750  SmallVector<ValueVector> remappedValues;
1751  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1752  remappedValues)))
1753  return nullptr;
1754  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
1755  return remappedValues.front().front();
1756 }
1757 
// Remaps every value in `keys` and appends the results to `results`, in order.
// Returns failure when remapping fails. Only 1:1 remappings are supported
// (asserted).
// NOTE(review): the line carrying the function name is not visible in this
// listing.
1758 LogicalResult
1760  SmallVectorImpl<Value> &results) {
      // Nothing to remap.
1761  if (keys.empty())
1762  return success();
1763  SmallVector<ValueVector> remapped;
1764  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1765  remapped)))
1766  return failure();
1767  for (const auto &values : remapped) {
1768  assert(values.size() == 1 && "1:N conversion not supported");
1769  results.push_back(values.front());
1770  }
1771  return success();
1772 }
1773 
1775  Block::iterator before,
1776  ValueRange argValues) {
      // Inlines the operations of `source` into `dest` before `before`,
      // replaces the arguments of `source` with `argValues`, and finally erases
      // `source`.
      // NOTE(review): the first signature line is not visible in this listing.
1777 #ifndef NDEBUG
1778  assert(argValues.size() == source->getNumArguments() &&
1779  "incorrect # of argument replacement values");
1780  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1781  "attempting to inline a block from a replaced/erased op");
1782  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1783  "attempting to inline a block into a replaced/erased op");
1784  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1785  // The source block will be deleted, so it should not have any users (i.e.,
1786  // there should be no predecessors).
1787  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1788  "expected 'source' to have no predecessors");
1789 #endif // NDEBUG
1790 
1791  // If a listener is attached to the dialect conversion, ops cannot be moved
1792  // to the destination block in bulk ("fast path"). This is because at the time
1793  // the notifications are sent, it is unknown which ops were moved. Instead,
1794  // ops should be moved one-by-one ("slow path"), so that a separate
1795  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1796  // a bit more efficient, so we try to do that when possible.
1797  bool fastPath = !impl->config.listener;
1798 
      // Only the bulk move is reported as a block inlining.
1799  if (fastPath)
1800  impl->notifyBlockBeingInlined(dest, source, before);
1801 
1802  // Replace all uses of block arguments.
1803  for (auto it : llvm::zip(source->getArguments(), argValues))
1804  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1805 
1806  if (fastPath) {
1807  // Move all ops at once.
1808  dest->getOperations().splice(before, source->getOperations());
1809  } else {
1810  // Move op by op.
1811  while (!source->empty())
1812  moveOpBefore(&source->front(), dest, before);
1813  }
1814 
1815  // Erase the source block.
1816  eraseBlock(source);
1817 }
1818 
// Marks the start of an in-place modification of `op` by appending a
// `ModifyOperationRewrite`. In assert-enabled builds `op` is also added to
// `pendingRootUpdates`.
// NOTE(review): the signature line of this definition is not visible in this
// listing.
1820  assert(!impl->wasOpReplaced(op) &&
1821  "attempting to modify a replaced/erased op");
1822 #ifndef NDEBUG
1823  impl->pendingRootUpdates.insert(op);
1824 #endif
1825  impl->appendRewrite<ModifyOperationRewrite>(op);
1826 }
1827 
// Marks the end of an in-place modification of `op`. In assert-enabled builds
// `op` must have a pending update, which is removed here.
// NOTE(review): the signature line and one body line (listing line 1831) are
// not visible in this listing.
1829  assert(!impl->wasOpReplaced(op) &&
1830  "attempting to modify a replaced/erased op");
1832  // There is nothing to do here, we only need to track the operation at the
1833  // start of the update.
1834 #ifndef NDEBUG
1835  assert(impl->pendingRootUpdates.erase(op) &&
1836  "operation did not have a pending in-place update");
1837 #endif
1838 }
1839 
// Cancels the most recent in-place modification of `op`: the newest
// `ModifyOperationRewrite` for `op` is rolled back and removed from the
// rewrite list.
// NOTE(review): the signature line of this definition is not visible in this
// listing.
1841 #ifndef NDEBUG
1842  assert(impl->pendingRootUpdates.erase(op) &&
1843  "operation did not have a pending in-place update");
1844 #endif
1845  // Erase the last update for this operation.
1846  auto it = llvm::find_if(
1847  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1848  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1849  return modifyRewrite && modifyRewrite->getOperation() == op;
1850  });
1851  assert(it != impl->rewrites.rend() && "no root update started on op");
1852  (*it)->rollback();
      // Convert the reverse iterator into a forward index into `rewrites`.
1853  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1854  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1855 }
1856 
// Returns a reference to the implementation object held in `impl`.
// NOTE(review): the signature line of this definition is not visible in this
// listing.
1858  return *impl;
1859 }
1860 
1861 //===----------------------------------------------------------------------===//
1862 // ConversionPattern
1863 //===----------------------------------------------------------------------===//
1864 
1866  ArrayRef<ValueRange> operands) const {
      // Flattens `operands` into one value per range. A range whose size is not
      // exactly 1 triggers a fatal error naming this pattern.
      // NOTE(review): the first signature line is not visible in this listing.
1867  SmallVector<Value> oneToOneOperands;
1868  oneToOneOperands.reserve(operands.size());
1869  for (ValueRange operand : operands) {
1870  if (operand.size() != 1)
1871  llvm::report_fatal_error("pattern '" + getDebugName() +
1872  "' does not support 1:N conversion");
1873  oneToOneOperands.push_back(operand.front());
1874  }
1875  return oneToOneOperands;
1876 }
1877 
// Entry point taking a plain `PatternRewriter`: remaps the operands of `op`
// and forwards to the overload that receives the remapped operands together
// with the conversion rewriter. Returns failure when remapping fails.
// NOTE(review): the line carrying the function name is not visible in this
// listing.
1878 LogicalResult
1880  PatternRewriter &rewriter) const {
      // The rewriter is unconditionally treated as a
      // `ConversionPatternRewriter`.
1881  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1882  auto &rewriterImpl = dialectRewriter.getImpl();
1883 
1884  // Track the current conversion pattern type converter in the rewriter.
1885  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1886  getTypeConverter());
1887 
1888  // Remap the operands of the operation.
1889  SmallVector<ValueVector> remapped;
1890  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1891  op->getOperands(), remapped))) {
1892  return failure();
1893  }
      // View each remapped value vector as a `ValueRange` for the callee.
1894  SmallVector<ValueRange> remappedAsRange =
1895  llvm::to_vector_of<ValueRange>(remapped);
1896  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
1897 }
1898 
1899 //===----------------------------------------------------------------------===//
1900 // OperationLegalizer
1901 //===----------------------------------------------------------------------===//
1902 
1903 namespace {
1904 /// A set of rewrite patterns that can be used to legalize a given operation.
1905 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1906 
1907 /// This class defines a recursive operation legalizer.
     /// NOTE(review): several declarations below have a parameter line that is
     /// not visible in this listing (gaps in the listing numbers).
1908 class OperationLegalizer {
1909 public:
1910  using LegalizationAction = ConversionTarget::LegalizationAction;
1911 
      /// Builds the legalizer for the given target and configuration.
1912  OperationLegalizer(const ConversionTarget &targetInfo,
1914  const ConversionConfig &config);
1915 
1916  /// Returns true if the given operation is known to be illegal on the target.
1917  bool isIllegal(Operation *op) const;
1918 
1919  /// Attempt to legalize the given operation. Returns success if the operation
1920  /// was legalized, failure otherwise.
1921  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1922 
1923  /// Returns the conversion target in use by the legalizer.
1924  const ConversionTarget &getTarget() { return target; }
1925 
1926 private:
1927  /// Attempt to legalize the given operation by folding it.
1928  LogicalResult legalizeWithFold(Operation *op,
1929  ConversionPatternRewriter &rewriter);
1930 
1931  /// Attempt to legalize the given operation by applying a pattern. Returns
1932  /// success if the operation was legalized, failure otherwise.
1933  LogicalResult legalizeWithPattern(Operation *op,
1934  ConversionPatternRewriter &rewriter);
1935 
1936  /// Return true if the given pattern may be applied to the given operation,
1937  /// false otherwise.
1938  bool canApplyPattern(Operation *op, const Pattern &pattern,
1939  ConversionPatternRewriter &rewriter);
1940 
1941  /// Legalize the resultant IR after successfully applying the given pattern.
1942  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1943  ConversionPatternRewriter &rewriter,
1944  RewriterState &curState);
1945 
1946  /// Legalizes the actions registered during the execution of a pattern.
1947  LogicalResult
1948  legalizePatternBlockRewrites(Operation *op,
1949  ConversionPatternRewriter &rewriter,
1951  RewriterState &state, RewriterState &newState);
      /// Legalizes the operations created between `state` and `newState`.
1952  LogicalResult legalizePatternCreatedOperations(
1954  RewriterState &state, RewriterState &newState);
      /// Legalizes the operations modified in place between `state` and
      /// `newState`.
1955  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1957  RewriterState &state,
1958  RewriterState &newState);
1959 
1960  //===--------------------------------------------------------------------===//
1961  // Cost Model
1962  //===--------------------------------------------------------------------===//
1963 
1964  /// Build an optimistic legalization graph given the provided patterns. This
1965  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1966  /// patterns for operations that are not directly legal, but may be
1967  /// transitively legal for the current target given the provided patterns.
1968  void buildLegalizationGraph(
1969  LegalizationPatterns &anyOpLegalizerPatterns,
1971 
1972  /// Compute the benefit of each node within the computed legalization graph.
1973  /// This orders the patterns within 'legalizerPatterns' based upon two
1974  /// criteria:
1975  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1976  /// represent the more direct mapping to the target.
1977  /// 2) When comparing patterns with the same legalization depth, prefer the
1978  /// pattern with the highest PatternBenefit. This allows for users to
1979  /// prefer specific legalizations over others.
1980  void computeLegalizationGraphBenefit(
1981  LegalizationPatterns &anyOpLegalizerPatterns,
1983 
1984  /// Compute the legalization depth when legalizing an operation of the given
1985  /// type.
1986  unsigned computeOpLegalizationDepth(
1987  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1989 
1990  /// Apply the conversion cost model to the given set of patterns, and return
1991  /// the smallest legalization depth of any of the patterns. See
1992  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1993  unsigned applyCostModelToPatterns(
1994  LegalizationPatterns &patterns,
1995  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1997 
1998  /// The current set of patterns that have been applied.
1999  SmallPtrSet<const Pattern *, 8> appliedPatterns;
2000 
2001  /// The legalization information provided by the target.
2002  const ConversionTarget &target;
2003 
2004  /// The pattern applicator to use for conversions.
2005  PatternApplicator applicator;
2006 
2007  /// Dialect conversion configuration.
2008  const ConversionConfig &config;
2009 };
2010 } // namespace
2011 
// Stores the target and configuration, builds the pattern applicator from
// `patterns`, then builds the legalization graph and applies the cost model.
// NOTE(review): the parameter line declaring `patterns` and the line declaring
// `legalizerPatterns` are not visible in this listing.
2012 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
2014  const ConversionConfig &config)
2015  : target(targetInfo), applicator(patterns), config(config) {
2016  // The set of patterns that can be applied to illegal operations to transform
2017  // them into legal ones.
2019  LegalizationPatterns anyOpLegalizerPatterns;
2020 
2021  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2022  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2023 }
2024 
2025 bool OperationLegalizer::isIllegal(Operation *op) const {
2026  return target.isIllegal(op);
2027 }
2028 
2029 LogicalResult
2030 OperationLegalizer::legalize(Operation *op,
2031  ConversionPatternRewriter &rewriter) {
2032 #ifndef NDEBUG
2033  const char *logLineComment =
2034  "//===-------------------------------------------===//\n";
2035 
2036  auto &logger = rewriter.getImpl().logger;
2037 #endif
2038  LLVM_DEBUG({
2039  logger.getOStream() << "\n";
2040  logger.startLine() << logLineComment;
2041  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
2042  << op << ") {\n";
2043  logger.indent();
2044 
2045  // If the operation has no regions, just print it here.
2046  if (op->getNumRegions() == 0) {
2047  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2048  logger.getOStream() << "\n\n";
2049  }
2050  });
2051 
2052  // Check if this operation is legal on the target.
2053  if (auto legalityInfo = target.isLegal(op)) {
2054  LLVM_DEBUG({
2055  logSuccess(
2056  logger, "operation marked legal by the target{0}",
2057  legalityInfo->isRecursivelyLegal
2058  ? "; NOTE: operation is recursively legal; skipping internals"
2059  : "");
2060  logger.startLine() << logLineComment;
2061  });
2062 
2063  // If this operation is recursively legal, mark its children as ignored so
2064  // that we don't consider them for legalization.
2065  if (legalityInfo->isRecursivelyLegal) {
2066  op->walk([&](Operation *nested) {
2067  if (op != nested)
2068  rewriter.getImpl().ignoredOps.insert(nested);
2069  });
2070  }
2071 
2072  return success();
2073  }
2074 
2075  // Check to see if the operation is ignored and doesn't need to be converted.
2076  if (rewriter.getImpl().isOpIgnored(op)) {
2077  LLVM_DEBUG({
2078  logSuccess(logger, "operation marked 'ignored' during conversion");
2079  logger.startLine() << logLineComment;
2080  });
2081  return success();
2082  }
2083 
2084  // If the operation isn't legal, try to fold it in-place.
2085  // TODO: Should we always try to do this, even if the op is
2086  // already legal?
2087  if (succeeded(legalizeWithFold(op, rewriter))) {
2088  LLVM_DEBUG({
2089  logSuccess(logger, "operation was folded");
2090  logger.startLine() << logLineComment;
2091  });
2092  return success();
2093  }
2094 
2095  // Otherwise, we need to apply a legalization pattern to this operation.
2096  if (succeeded(legalizeWithPattern(op, rewriter))) {
2097  LLVM_DEBUG({
2098  logSuccess(logger, "");
2099  logger.startLine() << logLineComment;
2100  });
2101  return success();
2102  }
2103 
2104  LLVM_DEBUG({
2105  logFailure(logger, "no matched legalization pattern");
2106  logger.startLine() << logLineComment;
2107  });
2108  return failure();
2109 }
2110 
2111 LogicalResult
2112 OperationLegalizer::legalizeWithFold(Operation *op,
2113  ConversionPatternRewriter &rewriter) {
2114  auto &rewriterImpl = rewriter.getImpl();
2115  LLVM_DEBUG({
2116  rewriterImpl.logger.startLine() << "* Fold {\n";
2117  rewriterImpl.logger.indent();
2118  });
2119  (void)rewriterImpl;
2120 
2121  // Try to fold the operation.
2122  SmallVector<Value, 2> replacementValues;
2123  SmallVector<Operation *, 2> newOps;
2124  rewriter.setInsertionPoint(op);
2125  if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2126  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2127  return failure();
2128  }
2129 
2130  // An empty list of replacement values indicates that the fold was in-place.
2131  // As the operation changed, a new legalization needs to be attempted.
2132  if (replacementValues.empty())
2133  return legalize(op, rewriter);
2134 
2135  // Recursively legalize any new constant operations.
2136  for (Operation *newOp : newOps) {
2137  if (failed(legalize(newOp, rewriter))) {
2138  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2139  "failed to legalize generated constant '{0}'",
2140  newOp->getName()));
2141  // Legalization failed: erase all materialized constants.
2142  for (Operation *op : newOps)
2143  rewriter.eraseOp(op);
2144  return failure();
2145  }
2146  }
2147 
2148  // Insert a replacement for 'op' with the folded replacement values.
2149  rewriter.replaceOp(op, replacementValues);
2150 
2151  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2152  return success();
2153 }
2154 
// Attempts to legalize `op` by running the pattern applicator with three
// callbacks: `canApply` filters patterns, `onFailure` resets the rewriter state
// after a failed pattern, and `onSuccess` legalizes the IR produced by a
// pattern that applied.
// NOTE(review): the line declaring `diag` inside `onFailure` is not visible in
// this listing.
2155 LogicalResult
2156 OperationLegalizer::legalizeWithPattern(Operation *op,
2157  ConversionPatternRewriter &rewriter) {
2158  auto &rewriterImpl = rewriter.getImpl();
2159 
2160  // Functor that returns if the given pattern may be applied.
2161  auto canApply = [&](const Pattern &pattern) {
2162  bool canApply = canApplyPattern(op, pattern, rewriter);
2163  if (canApply && config.listener)
2164  config.listener->notifyPatternBegin(pattern, op);
2165  return canApply;
2166  };
2167 
2168  // Functor that cleans up the rewriter state after a pattern failed to match.
      // `curState` is the snapshot that both callbacks below reset to.
2169  RewriterState curState = rewriterImpl.getCurrentState();
2170  auto onFailure = [&](const Pattern &pattern) {
2171  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2172  LLVM_DEBUG({
2173  logFailure(rewriterImpl.logger, "pattern failed to match");
2174  if (rewriterImpl.config.notifyCallback) {
2176  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2177  << "\" on op:\n"
2178  << *op;
2179  rewriterImpl.config.notifyCallback(diag);
2180  }
2181  });
2182  if (config.listener)
2183  config.listener->notifyPatternEnd(pattern, failure());
2184  rewriterImpl.resetState(curState, pattern.getDebugName());
2185  appliedPatterns.erase(&pattern);
2186  };
2187 
2188  // Functor that performs additional legalization when a pattern is
2189  // successfully applied.
2190  auto onSuccess = [&](const Pattern &pattern) {
2191  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2192  auto result = legalizePatternResult(op, pattern, rewriter, curState);
2193  appliedPatterns.erase(&pattern);
2194  if (failed(result)) {
      // An error is emitted only when pattern rollback is disallowed.
2195  if (!rewriterImpl.config.allowPatternRollback)
2196  op->emitError("pattern '")
2197  << pattern.getDebugName()
2198  << "' produced IR that could not be legalized";
2199  rewriterImpl.resetState(curState, pattern.getDebugName());
2200  }
2201  if (config.listener)
2202  config.listener->notifyPatternEnd(pattern, result);
2203  return result;
2204  };
2205 
2206  // Try to match and rewrite a pattern on this operation.
2207  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2208  onSuccess);
2209 }
2210 
2211 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2212  ConversionPatternRewriter &rewriter) {
2213  LLVM_DEBUG({
2214  auto &os = rewriter.getImpl().logger;
2215  os.getOStream() << "\n";
2216  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2217  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2218  os.getOStream() << ")' {\n";
2219  os.indent();
2220  });
2221 
2222  // Ensure that we don't cycle by not allowing the same pattern to be
2223  // applied twice in the same recursion stack if it is not known to be safe.
2224  if (!pattern.hasBoundedRewriteRecursion() &&
2225  !appliedPatterns.insert(&pattern).second) {
2226  LLVM_DEBUG(
2227  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2228  return false;
2229  }
2230  return true;
2231 }
2232 
2233 LogicalResult
2234 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2235  ConversionPatternRewriter &rewriter,
2236  RewriterState &curState) {
2237  auto &impl = rewriter.getImpl();
2238  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2239 
2240 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2241  // Check that the root was either replaced or updated in place.
2242  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2243  auto replacedRoot = [&] {
2244  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2245  };
2246  auto updatedRootInPlace = [&] {
2247  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2248  };
2249  if (!replacedRoot() && !updatedRootInPlace())
2250  llvm::report_fatal_error("expected pattern to replace the root operation");
2251 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2252 
2253  // Legalize each of the actions registered during application.
2254  RewriterState newState = impl.getCurrentState();
2255  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2256  newState)) ||
2257  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2258  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2259  newState))) {
2260  return failure();
2261  }
2262 
2263  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2264  return success();
2265 }
2266 
2267 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2268  Operation *op, ConversionPatternRewriter &rewriter,
2269  ConversionPatternRewriterImpl &impl, RewriterState &state,
2270  RewriterState &newState) {
2271  SmallPtrSet<Operation *, 16> operationsToIgnore;
2272 
2273  // If the pattern moved or created any blocks, make sure the types of block
2274  // arguments get legalized.
2275  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2276  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2277  if (!rewrite)
2278  continue;
2279  Block *block = rewrite->getBlock();
2280  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2281  ReplaceBlockArgRewrite>(rewrite))
2282  continue;
2283  // Only check blocks outside of the current operation.
2284  Operation *parentOp = block->getParentOp();
2285  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2286  continue;
2287 
2288  // If the region of the block has a type converter, try to convert the block
2289  // directly.
2290  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2291  std::optional<TypeConverter::SignatureConversion> conversion =
2292  converter->convertBlockSignature(block);
2293  if (!conversion) {
2294  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2295  "block"));
2296  return failure();
2297  }
2298  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2299  continue;
2300  }
2301 
2302  // Otherwise, check that this operation isn't one generated by this pattern.
2303  // This is because we will attempt to legalize the parent operation, and
2304  // blocks in regions created by this pattern will already be legalized later
2305  // on. If we haven't built the set yet, build it now.
2306  if (operationsToIgnore.empty()) {
2307  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2308  ++i) {
2309  auto *createOp =
2310  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2311  if (!createOp)
2312  continue;
2313  operationsToIgnore.insert(createOp->getOperation());
2314  }
2315  }
2316 
2317  // If this operation should be considered for re-legalization, try it.
2318  if (operationsToIgnore.insert(parentOp).second &&
2319  failed(legalize(parentOp, rewriter))) {
2320  LLVM_DEBUG(logFailure(impl.logger,
2321  "operation '{0}'({1}) became illegal after rewrite",
2322  parentOp->getName(), parentOp));
2323  return failure();
2324  }
2325  }
2326  return success();
2327 }
2328 
// Legalizes every operation recorded as created (`CreateOperationRewrite`)
// between `state` and `newState`; stops at the first operation that cannot be
// legalized.
// NOTE(review): the parameter line declaring `rewriter` and `impl` is not
// visible in this listing.
2329 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2331  RewriterState &state, RewriterState &newState) {
2332  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2333  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
      // Skip rewrites of any other kind.
2334  if (!createOp)
2335  continue;
2336  Operation *op = createOp->getOperation();
2337  if (failed(legalize(op, rewriter))) {
2338  LLVM_DEBUG(logFailure(impl.logger,
2339  "failed to legalize generated operation '{0}'({1})",
2340  op->getName(), op));
2341  return failure();
2342  }
2343  }
2344  return success();
2345 }
2346 
// Legalizes every operation recorded as modified in place
// (`ModifyOperationRewrite`) between `state` and `newState`; stops at the
// first operation that cannot be legalized.
// NOTE(review): the parameter line declaring `rewriter` and `impl` is not
// visible in this listing.
2347 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2349  RewriterState &state, RewriterState &newState) {
2350  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2351  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
      // Skip rewrites of any other kind.
2352  if (!rewrite)
2353  continue;
2354  Operation *op = rewrite->getOperation();
2355  if (failed(legalize(op, rewriter))) {
2356  LLVM_DEBUG(logFailure(
2357  impl.logger, "failed to legalize operation updated in-place '{0}'",
2358  op->getName()));
2359  return failure();
2360  }
2361  }
2362  return success();
2363 }
2364 
2365 //===----------------------------------------------------------------------===//
2366 // Cost Model
2367 //===----------------------------------------------------------------------===//
2368 
// Sorts the applicator's patterns into `anyOpLegalizerPatterns` (patterns
// without a root kind) and `legalizerPatterns` (per-root patterns whose
// generated operations are legal or can themselves be legalized).
// NOTE(review): the lines declaring `parentOps` and `invalidPatterns` are not
// visible in this listing.
2369 void OperationLegalizer::buildLegalizationGraph(
2370  LegalizationPatterns &anyOpLegalizerPatterns,
2371  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2372  // A mapping between an operation and a set of operations that can be used to
2373  // generate it.
2375  // A mapping between an operation and any currently invalid patterns it has.
2377  // A worklist of patterns to consider for legality.
2378  SetVector<const Pattern *> patternWorklist;
2379 
2380  // Build the mapping from operations to the parent ops that may generate them.
2381  applicator.walkAllPatterns([&](const Pattern &pattern) {
2382  std::optional<OperationName> root = pattern.getRootKind();
2383 
2384  // If the pattern has no specific root, we can't analyze the relationship
2385  // between the root op and generated operations. Given that, add all such
2386  // patterns to the legalization set.
2387  if (!root) {
2388  anyOpLegalizerPatterns.push_back(&pattern);
2389  return;
2390  }
2391 
2392  // Skip operations that are always known to be legal.
2393  if (target.getOpAction(*root) == LegalizationAction::Legal)
2394  return;
2395 
2396  // Add this pattern to the invalid set for the root op and record this root
2397  // as a parent for any generated operations.
2398  invalidPatterns[*root].insert(&pattern);
2399  for (auto op : pattern.getGeneratedOps())
2400  parentOps[op].insert(*root);
2401 
2402  // Add this pattern to the worklist.
2403  patternWorklist.insert(&pattern);
2404  });
2405 
2406  // If there are any patterns that don't have a specific root kind, we can't
2407  // make direct assumptions about what operations will never be legalized.
2408  // Note: Technically we could, but it would require an analysis that may
2409  // recurse into itself. It would be better to perform this kind of filtering
2410  // at a higher level than here anyways.
2411  if (!anyOpLegalizerPatterns.empty()) {
2412  for (const Pattern *pattern : patternWorklist)
2413  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2414  return;
2415  }
2416 
      // Worklist iteration: a pattern is accepted once every operation it
      // generates is either acceptable to the target or already has an accepted
      // pattern of its own.
2417  while (!patternWorklist.empty()) {
2418  auto *pattern = patternWorklist.pop_back_val();
2419 
2420  // Check to see if any of the generated operations are invalid.
2421  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2422  std::optional<LegalizationAction> action = target.getOpAction(op);
2423  return !legalizerPatterns.count(op) &&
2424  (!action || action == LegalizationAction::Illegal);
2425  }))
2426  continue;
2427 
2428  // Otherwise, if all of the generated operation are valid, this op is now
2429  // legal so add all of the child patterns to the worklist.
2430  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2431  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2432 
2433  // Add any invalid patterns of the parent operations to see if they have now
2434  // become legal.
2435  for (auto op : parentOps[*pattern->getRootKind()])
2436  patternWorklist.set_union(invalidPatterns[op]);
2437  }
2438 }
2439 
// Computes a legalization depth for every operation in `legalizerPatterns`,
// applies the cost model to the root-less patterns, and installs a cost model
// on the applicator that ranks each pattern by its position in its pattern
// list.
// NOTE(review): the statement executed when a pattern is not found in its list
// (listing line 2472) is not visible in this listing.
2440 void OperationLegalizer::computeLegalizationGraphBenefit(
2441  LegalizationPatterns &anyOpLegalizerPatterns,
2442  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2443  // The smallest pattern depth, when legalizing an operation.
2444  DenseMap<OperationName, unsigned> minOpPatternDepth;
2445 
2446  // For each operation that is transitively legal, compute a cost for it.
2447  for (auto &opIt : legalizerPatterns)
2448  if (!minOpPatternDepth.count(opIt.first))
2449  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2450  legalizerPatterns);
2451 
2452  // Apply the cost model to the patterns that can match any operation. Those
2453  // with a specific operation type are already resolved when computing the op
2454  // legalization depth.
2455  if (!anyOpLegalizerPatterns.empty())
2456  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2457  legalizerPatterns);
2458 
2459  // Apply a cost model to the pattern applicator. We order patterns first by
2460  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2461  // decreasing benefit.
2462  applicator.applyCostModel([&](const Pattern &pattern) {
2463  ArrayRef<const Pattern *> orderedPatternList;
2464  if (std::optional<OperationName> rootName = pattern.getRootKind())
2465  orderedPatternList = legalizerPatterns[*rootName];
2466  else
2467  orderedPatternList = anyOpLegalizerPatterns;
2468 
2469  // If the pattern is not found, then it was removed and cannot be matched.
2470  auto *it = llvm::find(orderedPatternList, &pattern);
2471  if (it == orderedPatternList.end())
2473 
2474  // Patterns found earlier in the list have higher benefit.
2475  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2476  });
2477 }
2478 
2479 unsigned OperationLegalizer::computeOpLegalizationDepth(
2480  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2481  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2482  // Check for existing depth.
2483  auto depthIt = minOpPatternDepth.find(op);
2484  if (depthIt != minOpPatternDepth.end())
2485  return depthIt->second;
2486 
2487  // If a mapping for this operation does not exist, then this operation
2488  // is always legal. Return 0 as the depth for a directly legal operation.
2489  auto opPatternsIt = legalizerPatterns.find(op);
2490  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2491  return 0u;
2492 
2493  // Record this initial depth in case we encounter this op again when
2494  // recursively computing the depth.
2495  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2496 
2497  // Apply the cost model to the operation patterns, and update the minimum
2498  // depth.
2499  unsigned minDepth = applyCostModelToPatterns(
2500  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2501  minOpPatternDepth[op] = minDepth;
2502  return minDepth;
2503 }
2504 
2505 unsigned OperationLegalizer::applyCostModelToPatterns(
2506  LegalizationPatterns &patterns,
2507  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2508  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2509  unsigned minDepth = std::numeric_limits<unsigned>::max();
2510 
2511  // Compute the depth for each pattern within the set.
2512  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2513  patternsByDepth.reserve(patterns.size());
2514  for (const Pattern *pattern : patterns) {
2515  unsigned depth = 1;
2516  for (auto generatedOp : pattern->getGeneratedOps()) {
2517  unsigned generatedOpDepth = computeOpLegalizationDepth(
2518  generatedOp, minOpPatternDepth, legalizerPatterns);
2519  depth = std::max(depth, generatedOpDepth + 1);
2520  }
2521  patternsByDepth.emplace_back(pattern, depth);
2522 
2523  // Update the minimum depth of the pattern list.
2524  minDepth = std::min(minDepth, depth);
2525  }
2526 
2527  // If the operation only has one legalization pattern, there is no need to
2528  // sort them.
2529  if (patternsByDepth.size() == 1)
2530  return minDepth;
2531 
2532  // Sort the patterns by those likely to be the most beneficial.
2533  llvm::stable_sort(patternsByDepth,
2534  [](const std::pair<const Pattern *, unsigned> &lhs,
2535  const std::pair<const Pattern *, unsigned> &rhs) {
2536  // First sort by the smaller pattern legalization
2537  // depth.
2538  if (lhs.second != rhs.second)
2539  return lhs.second < rhs.second;
2540 
2541  // Then sort by the larger pattern benefit.
2542  auto lhsBenefit = lhs.first->getBenefit();
2543  auto rhsBenefit = rhs.first->getBenefit();
2544  return lhsBenefit > rhsBenefit;
2545  });
2546 
2547  // Update the legalization pattern to use the new sorted list.
2548  patterns.clear();
2549  for (auto &patternIt : patternsByDepth)
2550  patterns.push_back(patternIt.first);
2551  return minDepth;
2552 }
2553 
2554 //===----------------------------------------------------------------------===//
2555 // OperationConverter
2556 //===----------------------------------------------------------------------===//
namespace {
/// Selects how `OperationConverter` reacts to operations that cannot be
/// legalized.
enum OpConversionMode {
  /// Partial conversion: an operation that fails to legalize may stay in the
  /// IR, so legal and illegal operations can co-exist.
  Partial,

  /// Full conversion: the conversion only succeeds if every operation is legal
  /// for the given target.
  Full,

  /// Analysis conversion: operations are only checked for legality; on
  /// success no rewrites are actually applied to them.
  Analysis,
};
} // namespace
2572 
2573 namespace mlir {
2574 // This class converts operations to a given conversion target via a set of
2575 // rewrite patterns. The conversion behaves differently depending on the
2576 // conversion mode.
2578  explicit OperationConverter(const ConversionTarget &target,
2580  const ConversionConfig &config,
2581  OpConversionMode mode)
2582  : config(config), opLegalizer(target, patterns, this->config),
2583  mode(mode) {}
2584 
2585  /// Converts the given operations to the conversion target.
2586  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2587 
2588 private:
2589  /// Converts an operation with the given rewriter.
2590  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2591 
2592  /// Dialect conversion configuration.
2593  ConversionConfig config;
2594 
2595  /// The legalizer to use when converting operations.
2596  OperationLegalizer opLegalizer;
2597 
2598  /// The conversion mode to use when legalizing operations.
2599  OpConversionMode mode;
2600 };
2601 } // namespace mlir
2602 
2603 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2604  Operation *op) {
2605  // Legalize the given operation.
2606  if (failed(opLegalizer.legalize(op, rewriter))) {
2607  // Handle the case of a failed conversion for each of the different modes.
2608  // Full conversions expect all operations to be converted.
2609  if (mode == OpConversionMode::Full)
2610  return op->emitError()
2611  << "failed to legalize operation '" << op->getName() << "'";
2612  // Partial conversions allow conversions to fail iff the operation was not
2613  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2614  // set, non-legalizable ops are added to that set.
2615  if (mode == OpConversionMode::Partial) {
2616  if (opLegalizer.isIllegal(op))
2617  return op->emitError()
2618  << "failed to legalize operation '" << op->getName()
2619  << "' that was explicitly marked illegal";
2620  if (config.unlegalizedOps)
2621  config.unlegalizedOps->insert(op);
2622  }
2623  } else if (mode == OpConversionMode::Analysis) {
2624  // Analysis conversions don't fail if any operations fail to legalize,
2625  // they are only interested in the operations that were successfully
2626  // legalized.
2627  if (config.legalizableOps)
2628  config.legalizableOps->insert(op);
2629  }
2630  return success();
2631 }
2632 
2633 static LogicalResult
2635  UnresolvedMaterializationRewrite *rewrite) {
2636  UnrealizedConversionCastOp op = rewrite->getOperation();
2637  assert(!op.use_empty() &&
2638  "expected that dead materializations have already been DCE'd");
2639  Operation::operand_range inputOperands = op.getOperands();
2640 
2641  // Try to materialize the conversion.
2642  if (const TypeConverter *converter = rewrite->getConverter()) {
2643  rewriter.setInsertionPoint(op);
2644  SmallVector<Value> newMaterialization;
2645  switch (rewrite->getMaterializationKind()) {
2646  case MaterializationKind::Target:
2647  newMaterialization = converter->materializeTargetConversion(
2648  rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2649  rewrite->getOriginalType());
2650  break;
2651  case MaterializationKind::Source:
2652  assert(op->getNumResults() == 1 && "expected single result");
2653  Value sourceMat = converter->materializeSourceConversion(
2654  rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2655  if (sourceMat)
2656  newMaterialization.push_back(sourceMat);
2657  break;
2658  }
2659  if (!newMaterialization.empty()) {
2660 #ifndef NDEBUG
2661  ValueRange newMaterializationRange(newMaterialization);
2662  assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
2663  "materialization callback produced value of incorrect type");
2664 #endif // NDEBUG
2665  rewriter.replaceOp(op, newMaterialization);
2666  return success();
2667  }
2668  }
2669 
2670  InFlightDiagnostic diag = op->emitError()
2671  << "failed to legalize unresolved materialization "
2672  "from ("
2673  << inputOperands.getTypes() << ") to ("
2674  << op.getResultTypes()
2675  << ") that remained live after conversion";
2676  diag.attachNote(op->getUsers().begin()->getLoc())
2677  << "see existing live user here: " << *op->getUsers().begin();
2678  return failure();
2679 }
2680 
2682  if (ops.empty())
2683  return success();
2684  const ConversionTarget &target = opLegalizer.getTarget();
2685 
2686  // Compute the set of operations and blocks to convert.
2687  SmallVector<Operation *> toConvert;
2688  for (auto *op : ops) {
2690  [&](Operation *op) {
2691  toConvert.push_back(op);
2692  // Don't check this operation's children for conversion if the
2693  // operation is recursively legal.
2694  auto legalityInfo = target.isLegal(op);
2695  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2696  return WalkResult::skip();
2697  return WalkResult::advance();
2698  });
2699  }
2700 
2701  // Convert each operation and discard rewrites on failure.
2702  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2703  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2704 
2705  for (auto *op : toConvert) {
2706  if (failed(convert(rewriter, op))) {
2707  // Dialect conversion failed.
2708  if (rewriterImpl.config.allowPatternRollback) {
2709  // Rollback is allowed: restore the original IR.
2710  rewriterImpl.undoRewrites();
2711  } else {
2712  // Rollback is not allowed: apply all modifications that have been
2713  // performed so far.
2714  rewriterImpl.applyRewrites();
2715  }
2716  return failure();
2717  }
2718  }
2719 
2720  // After a successful conversion, apply rewrites.
2721  rewriterImpl.applyRewrites();
2722 
2723  // Gather all unresolved materializations.
2726  &materializations = rewriterImpl.unresolvedMaterializations;
2727  for (auto it : materializations)
2728  allCastOps.push_back(it.first);
2729 
2730  // Reconcile all UnrealizedConversionCastOps that were inserted by the
2731  // dialect conversion framework. (Not the ones that were inserted by
2732  // patterns.)
2733  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2734  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2735 
2736  // Try to legalize all unresolved materializations.
2737  if (config.buildMaterializations) {
2738  IRRewriter rewriter(rewriterImpl.context, config.listener);
2739  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2740  auto it = materializations.find(castOp);
2741  assert(it != materializations.end() && "inconsistent state");
2742  if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2743  return failure();
2744  }
2745  }
2746 
2747  return success();
2748 }
2749 
2750 //===----------------------------------------------------------------------===//
2751 // Reconcile Unrealized Casts
2752 //===----------------------------------------------------------------------===//
2753 
2756  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2757  SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
2758  // This set is maintained only if `remainingCastOps` is provided.
2759  DenseSet<Operation *> erasedOps;
2760 
2761  // Helper function that adds all operands to the worklist that are an
2762  // unrealized_conversion_cast op result.
2763  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2764  for (Value v : castOp.getInputs())
2765  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2766  worklist.insert(inputCastOp);
2767  };
2768 
2769  // Helper function that returns the unrealized_conversion_cast op that
2770  // defines all inputs of the given op (in the same order). Returns "nullptr"
2771  // if there is no such op.
2772  auto getInputCast =
2773  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2774  if (castOp.getInputs().empty())
2775  return {};
2776  auto inputCastOp =
2777  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2778  if (!inputCastOp)
2779  return {};
2780  if (inputCastOp.getOutputs() != castOp.getInputs())
2781  return {};
2782  return inputCastOp;
2783  };
2784 
2785  // Process ops in the worklist bottom-to-top.
2786  while (!worklist.empty()) {
2787  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2788  if (castOp->use_empty()) {
2789  // DCE: If the op has no users, erase it. Add the operands to the
2790  // worklist to find additional DCE opportunities.
2791  enqueueOperands(castOp);
2792  if (remainingCastOps)
2793  erasedOps.insert(castOp.getOperation());
2794  castOp->erase();
2795  continue;
2796  }
2797 
2798  // Traverse the chain of input cast ops to see if an op with the same
2799  // input types can be found.
2800  UnrealizedConversionCastOp nextCast = castOp;
2801  while (nextCast) {
2802  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2803  // Found a cast where the input types match the output types of the
2804  // matched op. We can directly use those inputs and the matched op can
2805  // be removed.
2806  enqueueOperands(castOp);
2807  castOp.replaceAllUsesWith(nextCast.getInputs());
2808  if (remainingCastOps)
2809  erasedOps.insert(castOp.getOperation());
2810  castOp->erase();
2811  break;
2812  }
2813  nextCast = getInputCast(nextCast);
2814  }
2815  }
2816 
2817  if (remainingCastOps)
2818  for (UnrealizedConversionCastOp op : castOps)
2819  if (!erasedOps.contains(op.getOperation()))
2820  remainingCastOps->push_back(op);
2821 }
2822 
2823 //===----------------------------------------------------------------------===//
2824 // Type Conversion
2825 //===----------------------------------------------------------------------===//
2826 
2828  ArrayRef<Type> types) {
2829  assert(!types.empty() && "expected valid types");
2830  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2831  addInputs(types);
2832 }
2833 
2835  assert(!types.empty() &&
2836  "1->0 type remappings don't need to be added explicitly");
2837  argTypes.append(types.begin(), types.end());
2838 }
2839 
2840 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2841  unsigned newInputNo,
2842  unsigned newInputCount) {
2843  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2844  assert(newInputCount != 0 && "expected valid input count");
2845  remappedInputs[origInputNo] =
2846  InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
2847 }
2848 
2850  unsigned origInputNo, ArrayRef<Value> replacements) {
2851  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2852  remappedInputs[origInputNo] = InputMapping{
2853  origInputNo, /*size=*/0,
2854  SmallVector<Value, 1>(replacements.begin(), replacements.end())};
2855 }
2856 
2858  SmallVectorImpl<Type> &results) const {
2859  assert(t && "expected non-null type");
2860 
2861  {
2862  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2863  std::defer_lock);
2865  cacheReadLock.lock();
2866  auto existingIt = cachedDirectConversions.find(t);
2867  if (existingIt != cachedDirectConversions.end()) {
2868  if (existingIt->second)
2869  results.push_back(existingIt->second);
2870  return success(existingIt->second != nullptr);
2871  }
2872  auto multiIt = cachedMultiConversions.find(t);
2873  if (multiIt != cachedMultiConversions.end()) {
2874  results.append(multiIt->second.begin(), multiIt->second.end());
2875  return success();
2876  }
2877  }
2878  // Walk the added converters in reverse order to apply the most recently
2879  // registered first.
2880  size_t currentCount = results.size();
2881 
2882  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2883  std::defer_lock);
2884 
2885  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2886  if (std::optional<LogicalResult> result = converter(t, results)) {
2888  cacheWriteLock.lock();
2889  if (!succeeded(*result)) {
2890  assert(results.size() == currentCount &&
2891  "failed type conversion should not change results");
2892  cachedDirectConversions.try_emplace(t, nullptr);
2893  return failure();
2894  }
2895  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2896  if (newTypes.size() == 1)
2897  cachedDirectConversions.try_emplace(t, newTypes.front());
2898  else
2899  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2900  return success();
2901  } else {
2902  assert(results.size() == currentCount &&
2903  "failed type conversion should not change results");
2904  }
2905  }
2906  return failure();
2907 }
2908 
2910  // Use the multi-type result version to convert the type.
2911  SmallVector<Type, 1> results;
2912  if (failed(convertType(t, results)))
2913  return nullptr;
2914 
2915  // Check to ensure that only one type was produced.
2916  return results.size() == 1 ? results.front() : nullptr;
2917 }
2918 
2919 LogicalResult
2921  SmallVectorImpl<Type> &results) const {
2922  for (Type type : types)
2923  if (failed(convertType(type, results)))
2924  return failure();
2925  return success();
2926 }
2927 
2928 bool TypeConverter::isLegal(Type type) const {
2929  return convertType(type) == type;
2930 }
2932  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2933 }
2934 
2935 bool TypeConverter::isLegal(Region *region) const {
2936  return llvm::all_of(*region, [this](Block &block) {
2937  return isLegal(block.getArgumentTypes());
2938  });
2939 }
2940 
2941 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2942  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2943 }
2944 
2945 LogicalResult
2947  SignatureConversion &result) const {
2948  // Try to convert the given input type.
2949  SmallVector<Type, 1> convertedTypes;
2950  if (failed(convertType(type, convertedTypes)))
2951  return failure();
2952 
2953  // If this argument is being dropped, there is nothing left to do.
2954  if (convertedTypes.empty())
2955  return success();
2956 
2957  // Otherwise, add the new inputs.
2958  result.addInputs(inputNo, convertedTypes);
2959  return success();
2960 }
2961 LogicalResult
2963  SignatureConversion &result,
2964  unsigned origInputOffset) const {
2965  for (unsigned i = 0, e = types.size(); i != e; ++i)
2966  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2967  return failure();
2968  return success();
2969 }
2970 
2972  Location loc, Type resultType,
2973  ValueRange inputs) const {
2974  for (const SourceMaterializationCallbackFn &fn :
2975  llvm::reverse(sourceMaterializations))
2976  if (Value result = fn(builder, resultType, inputs, loc))
2977  return result;
2978  return nullptr;
2979 }
2980 
2982  Location loc, Type resultType,
2983  ValueRange inputs,
2984  Type originalType) const {
2986  builder, loc, TypeRange(resultType), inputs, originalType);
2987  if (result.empty())
2988  return nullptr;
2989  assert(result.size() == 1 && "expected single result");
2990  return result.front();
2991 }
2992 
2994  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
2995  Type originalType) const {
2996  for (const TargetMaterializationCallbackFn &fn :
2997  llvm::reverse(targetMaterializations)) {
2998  SmallVector<Value> result =
2999  fn(builder, resultTypes, inputs, loc, originalType);
3000  if (result.empty())
3001  continue;
3002  assert(TypeRange(ValueRange(result)) == resultTypes &&
3003  "callback produced incorrect number of values or values with "
3004  "incorrect types");
3005  return result;
3006  }
3007  return {};
3008 }
3009 
3010 std::optional<TypeConverter::SignatureConversion>
3012  SignatureConversion conversion(block->getNumArguments());
3013  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3014  return std::nullopt;
3015  return conversion;
3016 }
3017 
3018 //===----------------------------------------------------------------------===//
3019 // Type attribute conversion
3020 //===----------------------------------------------------------------------===//
3023  return AttributeConversionResult(attr, resultTag);
3024 }
3025 
3028  return AttributeConversionResult(nullptr, naTag);
3029 }
3030 
3033  return AttributeConversionResult(nullptr, abortTag);
3034 }
3035 
3037  return impl.getInt() == resultTag;
3038 }
3039 
3041  return impl.getInt() == naTag;
3042 }
3043 
3045  return impl.getInt() == abortTag;
3046 }
3047 
3049  assert(hasResult() && "Cannot get result from N/A or abort");
3050  return impl.getPointer();
3051 }
3052 
3053 std::optional<Attribute>
3055  for (const TypeAttributeConversionCallbackFn &fn :
3056  llvm::reverse(typeAttributeConversions)) {
3057  AttributeConversionResult res = fn(type, attr);
3058  if (res.hasResult())
3059  return res.getResult();
3060  if (res.isAbort())
3061  return std::nullopt;
3062  }
3063  return std::nullopt;
3064 }
3065 
3066 //===----------------------------------------------------------------------===//
3067 // FunctionOpInterfaceSignatureConversion
3068 //===----------------------------------------------------------------------===//
3069 
3070 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3071  const TypeConverter &typeConverter,
3072  ConversionPatternRewriter &rewriter) {
3073  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3074  if (!type)
3075  return failure();
3076 
3077  // Convert the original function types.
3078  TypeConverter::SignatureConversion result(type.getNumInputs());
3079  SmallVector<Type, 1> newResults;
3080  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3081  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3082  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3083  typeConverter, &result)))
3084  return failure();
3085 
3086  // Update the function signature in-place.
3087  auto newType = FunctionType::get(rewriter.getContext(),
3088  result.getConvertedTypes(), newResults);
3089 
3090  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3091 
3092  return success();
3093 }
3094 
3095 /// Create a default conversion pattern that rewrites the type signature of a
3096 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3097 /// represent their type.
3098 namespace {
3099 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3100  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3101  MLIRContext *ctx,
3102  const TypeConverter &converter)
3103  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3104 
3105  LogicalResult
3106  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3107  ConversionPatternRewriter &rewriter) const override {
3108  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3109  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3110  }
3111 };
3112 
3113 struct AnyFunctionOpInterfaceSignatureConversion
3114  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3116 
3117  LogicalResult
3118  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3119  ConversionPatternRewriter &rewriter) const override {
3120  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3121  }
3122 };
3123 } // namespace
3124 
3125 FailureOr<Operation *>
3127  const TypeConverter &converter,
3128  ConversionPatternRewriter &rewriter) {
3129  assert(op && "Invalid op");
3130  Location loc = op->getLoc();
3131  if (converter.isLegal(op))
3132  return rewriter.notifyMatchFailure(loc, "op already legal");
3133 
3134  OperationState newOp(loc, op->getName());
3135  newOp.addOperands(operands);
3136 
3137  SmallVector<Type> newResultTypes;
3138  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3139  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3140 
3141  newOp.addTypes(newResultTypes);
3142  newOp.addAttributes(op->getAttrs());
3143  return rewriter.create(newOp);
3144 }
3145 
3147  StringRef functionLikeOpName, RewritePatternSet &patterns,
3148  const TypeConverter &converter) {
3149  patterns.add<FunctionOpInterfaceSignatureConversion>(
3150  functionLikeOpName, patterns.getContext(), converter);
3151 }
3152 
3154  RewritePatternSet &patterns, const TypeConverter &converter) {
3155  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3156  converter, patterns.getContext());
3157 }
3158 
3159 //===----------------------------------------------------------------------===//
3160 // ConversionTarget
3161 //===----------------------------------------------------------------------===//
3162 
3164  LegalizationAction action) {
3165  legalOperations[op].action = action;
3166 }
3167 
3169  LegalizationAction action) {
3170  for (StringRef dialect : dialectNames)
3171  legalDialects[dialect] = action;
3172 }
3173 
3175  -> std::optional<LegalizationAction> {
3176  std::optional<LegalizationInfo> info = getOpInfo(op);
3177  return info ? info->action : std::optional<LegalizationAction>();
3178 }
3179 
3181  -> std::optional<LegalOpDetails> {
3182  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3183  if (!info)
3184  return std::nullopt;
3185 
3186  // Returns true if this operation instance is known to be legal.
3187  auto isOpLegal = [&] {
3188  // Handle dynamic legality with the provided legality function.
3189  if (info->action == LegalizationAction::Dynamic) {
3190  std::optional<bool> result = info->legalityFn(op);
3191  if (result)
3192  return *result;
3193  }
3194 
3195  // Otherwise, the operation is only legal if it was marked 'Legal'.
3196  return info->action == LegalizationAction::Legal;
3197  };
3198  if (!isOpLegal())
3199  return std::nullopt;
3200 
3201  // This operation is legal, compute any additional legality information.
3202  LegalOpDetails legalityDetails;
3203  if (info->isRecursivelyLegal) {
3204  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3205  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3206  legalityDetails.isRecursivelyLegal =
3207  legalityFnIt->second(op).value_or(true);
3208  } else {
3209  legalityDetails.isRecursivelyLegal = true;
3210  }
3211  }
3212  return legalityDetails;
3213 }
3214 
3216  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3217  if (!info)
3218  return false;
3219 
3220  if (info->action == LegalizationAction::Dynamic) {
3221  std::optional<bool> result = info->legalityFn(op);
3222  if (!result)
3223  return false;
3224 
3225  return !(*result);
3226  }
3227 
3228  return info->action == LegalizationAction::Illegal;
3229 }
3230 
3234  if (!oldCallback)
3235  return newCallback;
3236 
3237  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3238  Operation *op) -> std::optional<bool> {
3239  if (std::optional<bool> result = newCl(op))
3240  return *result;
3241 
3242  return oldCl(op);
3243  };
3244  return chain;
3245 }
3246 
3247 void ConversionTarget::setLegalityCallback(
3248  OperationName name, const DynamicLegalityCallbackFn &callback) {
3249  assert(callback && "expected valid legality callback");
3250  auto *infoIt = legalOperations.find(name);
3251  assert(infoIt != legalOperations.end() &&
3252  infoIt->second.action == LegalizationAction::Dynamic &&
3253  "expected operation to already be marked as dynamically legal");
3254  infoIt->second.legalityFn =
3255  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3256 }
3257 
3259  OperationName name, const DynamicLegalityCallbackFn &callback) {
3260  auto *infoIt = legalOperations.find(name);
3261  assert(infoIt != legalOperations.end() &&
3262  infoIt->second.action != LegalizationAction::Illegal &&
3263  "expected operation to already be marked as legal");
3264  infoIt->second.isRecursivelyLegal = true;
3265  if (callback)
3266  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3267  std::move(opRecursiveLegalityFns[name]), callback);
3268  else
3269  opRecursiveLegalityFns.erase(name);
3270 }
3271 
3272 void ConversionTarget::setLegalityCallback(
3273  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3274  assert(callback && "expected valid legality callback");
3275  for (StringRef dialect : dialects)
3276  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3277  std::move(dialectLegalityFns[dialect]), callback);
3278 }
3279 
3280 void ConversionTarget::setLegalityCallback(
3281  const DynamicLegalityCallbackFn &callback) {
3282  assert(callback && "expected valid legality callback");
3283  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3284 }
3285 
3286 auto ConversionTarget::getOpInfo(OperationName op) const
3287  -> std::optional<LegalizationInfo> {
3288  // Check for info for this specific operation.
3289  const auto *it = legalOperations.find(op);
3290  if (it != legalOperations.end())
3291  return it->second;
3292  // Check for info for the parent dialect.
3293  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3294  if (dialectIt != legalDialects.end()) {
3295  DynamicLegalityCallbackFn callback;
3296  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3297  if (dialectFn != dialectLegalityFns.end())
3298  callback = dialectFn->second;
3299  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3300  callback};
3301  }
3302  // Otherwise, check if we mark unknown operations as dynamic.
3303  if (unknownLegalityFn)
3304  return LegalizationInfo{LegalizationAction::Dynamic,
3305  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3306  return std::nullopt;
3307 }
3308 
3309 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3310 //===----------------------------------------------------------------------===//
3311 // PDL Configuration
3312 //===----------------------------------------------------------------------===//
3313 
3315  auto &rewriterImpl =
3316  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3317  rewriterImpl.currentTypeConverter = getTypeConverter();
3318 }
3319 
3321  auto &rewriterImpl =
3322  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3323  rewriterImpl.currentTypeConverter = nullptr;
3324 }
3325 
3326 /// Remap the given value using the rewriter and the type converter in the
3327 /// provided config.
3328 static FailureOr<SmallVector<Value>>
3330  SmallVector<Value> mappedValues;
3331  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3332  return failure();
3333  return std::move(mappedValues);
3334 }
3335 
3337  patterns.getPDLPatterns().registerRewriteFunction(
3338  "convertValue",
3339  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3340  auto results = pdllConvertValues(
3341  static_cast<ConversionPatternRewriter &>(rewriter), value);
3342  if (failed(results))
3343  return failure();
3344  return results->front();
3345  });
3346  patterns.getPDLPatterns().registerRewriteFunction(
3347  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3348  return pdllConvertValues(
3349  static_cast<ConversionPatternRewriter &>(rewriter), values);
3350  });
3351  patterns.getPDLPatterns().registerRewriteFunction(
3352  "convertType",
3353  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3354  auto &rewriterImpl =
3355  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3356  if (const TypeConverter *converter =
3357  rewriterImpl.currentTypeConverter) {
3358  if (Type newType = converter->convertType(type))
3359  return newType;
3360  return failure();
3361  }
3362  return type;
3363  });
3364  patterns.getPDLPatterns().registerRewriteFunction(
3365  "convertTypes",
3366  [](PatternRewriter &rewriter,
3367  TypeRange types) -> FailureOr<SmallVector<Type>> {
3368  auto &rewriterImpl =
3369  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3370  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3371  if (!converter)
3372  return SmallVector<Type>(types);
3373 
3374  SmallVector<Type> remappedTypes;
3375  if (failed(converter->convertTypes(types, remappedTypes)))
3376  return failure();
3377  return std::move(remappedTypes);
3378  });
3379 }
3380 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3381 
3382 //===----------------------------------------------------------------------===//
3383 // Op Conversion Entry Points
3384 //===----------------------------------------------------------------------===//
3385 
3386 //===----------------------------------------------------------------------===//
3387 // Partial Conversion
3388 //===----------------------------------------------------------------------===//
3389 
3391  ArrayRef<Operation *> ops, const ConversionTarget &target,
3393  OperationConverter opConverter(target, patterns, config,
3394  OpConversionMode::Partial);
3395  return opConverter.convertOperations(ops);
3396 }
3397 LogicalResult
3401  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3402 }
3403 
3404 //===----------------------------------------------------------------------===//
3405 // Full Conversion
3406 //===----------------------------------------------------------------------===//
3407 
3409  const ConversionTarget &target,
3412  OperationConverter opConverter(target, patterns, config,
3413  OpConversionMode::Full);
3414  return opConverter.convertOperations(ops);
3415 }
3417  const ConversionTarget &target,
3420  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3421 }
3422 
3423 //===----------------------------------------------------------------------===//
3424 // Analysis Conversion
3425 //===----------------------------------------------------------------------===//
3426 
3427 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3428 /// op is a top-level module op (which is expected to be isolated from above),
3429 /// return that op.
3431  // Check if there is a top-level operation within `ops`. If so, return that
3432  // op.
3433  for (Operation *op : ops) {
3434  if (!op->getParentOp()) {
3435 #ifndef NDEBUG
3436  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3437  "expected top-level op to be isolated from above");
3438  for (Operation *other : ops)
3439  assert(op->isAncestor(other) &&
3440  "expected ops to have a common ancestor");
3441 #endif // NDEBUG
3442  return op;
3443  }
3444  }
3445 
3446  // No top-level op. Find a common ancestor.
3447  Operation *commonAncestor =
3448  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3449  for (Operation *op : ops.drop_front()) {
3450  while (!commonAncestor->isProperAncestor(op)) {
3451  commonAncestor =
3453  assert(commonAncestor &&
3454  "expected to find a common isolated from above ancestor");
3455  }
3456  }
3457 
3458  return commonAncestor;
3459 }
3460 
3464 #ifndef NDEBUG
3465  if (config.legalizableOps)
3466  assert(config.legalizableOps->empty() && "expected empty set");
3467 #endif // NDEBUG
3468 
3469  // Clone the closest common ancestor that is isolated from above.
3470  Operation *commonAncestor = findCommonAncestor(ops);
3471  IRMapping mapping;
3472  Operation *clonedAncestor = commonAncestor->clone(mapping);
3473  // Compute inverse IR mapping.
3474  DenseMap<Operation *, Operation *> inverseOperationMap;
3475  for (auto &it : mapping.getOperationMap())
3476  inverseOperationMap[it.second] = it.first;
3477 
3478  // Convert the cloned operations. The original IR will remain unchanged.
3479  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3480  ops, [&](Operation *op) { return mapping.lookup(op); });
3481  OperationConverter opConverter(target, patterns, config,
3482  OpConversionMode::Analysis);
3483  LogicalResult status = opConverter.convertOperations(opsToConvert);
3484 
3485  // Remap `legalizableOps`, so that they point to the original ops and not the
3486  // cloned ops.
3487  if (config.legalizableOps) {
3488  DenseSet<Operation *> originalLegalizableOps;
3489  for (Operation *op : *config.legalizableOps)
3490  originalLegalizableOps.insert(inverseOperationMap[op]);
3491  *config.legalizableOps = std::move(originalLegalizableOps);
3492  }
3493 
3494  // Erase the cloned IR.
3495  clonedAncestor->erase();
3496  return status;
3497 }
3498 
3499 LogicalResult
3503  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3504 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
union mlir::linalg::@1205::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:68
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:96
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={}) override
PatternRewriter hook for inlining the ops of a block into another block.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument `from` with value `to`.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void dropAllUses()
Drop all uses of this object from their respective owners.
Definition: UseDefLists.h:202
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:729
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:32
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:325
Block::iterator getPoint() const
Definition: Builders.h:338
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:335
Block * getBlock() const
Definition: Builders.h:337
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:428
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:471
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getTypes() const
Definition: ValueRange.cpp:28
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:296
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:719
operand_iterator operand_begin()
Definition: Operation.h:374
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_iterator operand_end()
Definition: Operation.h:375
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:607
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
result_range getResults()
Definition: Operation.h:415
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:896
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
succ_iterator successor_end()
Definition: Operation.h:702
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
succ_iterator successor_begin()
Definition: Operation.h:701
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of `from` and replace them with `to`.
Definition: PatternMatch.h:601
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of `from` and replace them with `to` if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:593
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacements values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute `attr` present within the type `type` using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a range of values that remaps an existing signature in...
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp=nullptr)
Build an unresolved materialization operation given a range of output types and a list of input opera...
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
void notifyOpReplaced(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Notifies that an op is about to be replaced with the given values.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.