MLIR  21.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Iterators.h"
19 #include "llvm/ADT/ScopeExit.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "llvm/Support/SaveAndRestore.h"
25 #include "llvm/Support/ScopedPrinter.h"
26 #include <optional>
27 
28 using namespace mlir;
29 using namespace mlir::detail;
30 
31 #define DEBUG_TYPE "dialect-conversion"
32 
33 /// A utility function to log a successful result for the given reason.
34 template <typename... Args>
35 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
36  LLVM_DEBUG({
37  os.unindent();
38  os.startLine() << "} -> SUCCESS";
39  if (!fmt.empty())
40  os.getOStream() << " : "
41  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
42  os.getOStream() << "\n";
43  });
44 }
45 
46 /// A utility function to log a failure result for the given reason.
47 template <typename... Args>
48 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
49  LLVM_DEBUG({
50  os.unindent();
51  os.startLine() << "} -> FAILURE : "
52  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
53  << "\n";
54  });
55 }
56 
57 /// Helper function that computes an insertion point where the given value is
58 /// defined and can be used without a dominance violation.
60  Block *insertBlock = value.getParentBlock();
61  Block::iterator insertPt = insertBlock->begin();
62  if (OpResult inputRes = dyn_cast<OpResult>(value))
63  insertPt = ++inputRes.getOwner()->getIterator();
64  return OpBuilder::InsertPoint(insertBlock, insertPt);
65 }
66 
67 /// Helper function that computes an insertion point where the given values are
68 /// defined and can be used without a dominance violation.
70  assert(!vals.empty() && "expected at least one value");
71  DominanceInfo domInfo;
72  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
73  for (Value v : vals.drop_front()) {
74  // Choose the "later" insertion point.
76  if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
77  nextPt.getPoint())) {
78  // pt is before nextPt => choose nextPt.
79  pt = nextPt;
80  } else {
81 #ifndef NDEBUG
82  // nextPt should be before pt => choose pt.
83  // If pt, nextPt are no dominance relationship, then there is no valid
84  // insertion point at which all given values are defined.
85  bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
86  pt.getBlock(), pt.getPoint());
87  assert(dom && "unable to find valid insertion point");
88 #endif // NDEBUG
89  }
90  }
91  return pt;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // ConversionValueMapping
96 //===----------------------------------------------------------------------===//
97 
98 /// A vector of SSA values, optimized for the most common case of a single
99 /// value.
100 using ValueVector = SmallVector<Value, 1>;
101 
102 namespace {
103 
104 /// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
105 struct ValueVectorMapInfo {
106  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
107  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
108  static ::llvm::hash_code getHashValue(const ValueVector &val) {
109  return ::llvm::hash_combine_range(val);
110  }
111  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
112  return LHS == RHS;
113  }
114 };
115 
116 /// This class wraps a IRMapping to provide recursive lookup
117 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
118 struct ConversionValueMapping {
119  /// Return "true" if an SSA value is mapped to the given value. May return
120  /// false positives.
121  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
122 
123  /// Lookup the most recently mapped values with the desired types in the
124  /// mapping.
125  ///
126  /// Special cases:
127  /// - If the desired type range is empty, simply return the most recently
128  /// mapped values.
129  /// - If there is no mapping to the desired types, also return the most
130  /// recently mapped values.
131  /// - If there is no mapping for the given values at all, return the given
132  /// value.
133  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
134 
135  /// Lookup the given value within the map, or return an empty vector if the
136  /// value is not mapped. If it is mapped, this follows the same behavior
137  /// as `lookupOrDefault`.
138  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
139 
140  template <typename T>
141  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
142 
143  /// Map a value vector to the one provided.
144  template <typename OldVal, typename NewVal>
145  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
146  map(OldVal &&oldVal, NewVal &&newVal) {
147  LLVM_DEBUG({
148  ValueVector next(newVal);
149  while (true) {
150  assert(next != oldVal && "inserting cyclic mapping");
151  auto it = mapping.find(next);
152  if (it == mapping.end())
153  break;
154  next = it->second;
155  }
156  });
157  mappedTo.insert_range(newVal);
158 
159  mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
160  }
161 
162  /// Map a value vector or single value to the one provided.
163  template <typename OldVal, typename NewVal>
164  std::enable_if_t<!IsValueVector<OldVal>::value ||
165  !IsValueVector<NewVal>::value>
166  map(OldVal &&oldVal, NewVal &&newVal) {
167  if constexpr (IsValueVector<OldVal>{}) {
168  map(std::forward<OldVal>(oldVal), ValueVector{newVal});
169  } else if constexpr (IsValueVector<NewVal>{}) {
170  map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
171  } else {
172  map(ValueVector{oldVal}, ValueVector{newVal});
173  }
174  }
175 
176  void map(Value oldVal, SmallVector<Value> &&newVal) {
177  map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
178  }
179 
180  /// Drop the last mapping for the given values.
181  void erase(const ValueVector &value) { mapping.erase(value); }
182 
183 private:
184  /// Current value mappings.
186 
187  /// All SSA values that are mapped to. May contain false positives.
188  DenseSet<Value> mappedTo;
189 };
190 } // namespace
191 
193 ConversionValueMapping::lookupOrDefault(Value from,
194  TypeRange desiredTypes) const {
195  // Try to find the deepest values that have the desired types. If there is no
196  // such mapping, simply return the deepest values.
197  ValueVector desiredValue;
198  ValueVector current{from};
199  do {
200  // Store the current value if the types match.
201  if (TypeRange(ValueRange(current)) == desiredTypes)
202  desiredValue = current;
203 
204  // If possible, Replace each value with (one or multiple) mapped values.
205  ValueVector next;
206  for (Value v : current) {
207  auto it = mapping.find({v});
208  if (it != mapping.end()) {
209  llvm::append_range(next, it->second);
210  } else {
211  next.push_back(v);
212  }
213  }
214  if (next != current) {
215  // If at least one value was replaced, continue the lookup from there.
216  current = std::move(next);
217  continue;
218  }
219 
220  // Otherwise: Check if there is a mapping for the entire vector. Such
221  // mappings are materializations. (N:M mapping are not supported for value
222  // replacements.)
223  //
224  // Note: From a correctness point of view, materializations do not have to
225  // be stored (and looked up) in the mapping. But for performance reasons,
226  // we choose to reuse existing IR (when possible) instead of creating it
227  // multiple times.
228  auto it = mapping.find(current);
229  if (it == mapping.end()) {
230  // No mapping found: The lookup stops here.
231  break;
232  }
233  current = it->second;
234  } while (true);
235 
236  // If the desired values were found use them, otherwise default to the leaf
237  // values.
238  // Note: If `desiredTypes` is empty, this function always returns `current`.
239  return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
240 }
241 
242 ValueVector ConversionValueMapping::lookupOrNull(Value from,
243  TypeRange desiredTypes) const {
244  ValueVector result = lookupOrDefault(from, desiredTypes);
245  if (result == ValueVector{from} ||
246  (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
247  return {};
248  return result;
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // Rewriter and Translation State
253 //===----------------------------------------------------------------------===//
254 namespace {
/// A snapshot of the conversion rewriter's bookkeeping counters. Taking a
/// snapshot and later restoring it is how a set of rewrites is undone.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites{numRewrites}, numIgnoredOperations{numIgnoredOperations},
        numReplacedOps{numReplacedOps} {}

  unsigned numRewrites;          ///< Number of rewrites performed so far.
  unsigned numIgnoredOperations; ///< Number of ignored operations.
  unsigned numReplacedOps; ///< Number of replaced ops scheduled for erasure.
};
272 
273 //===----------------------------------------------------------------------===//
274 // IR rewrites
275 //===----------------------------------------------------------------------===//
276 
277 /// An IR rewrite that can be committed (upon success) or rolled back (upon
278 /// failure).
279 ///
280 /// The dialect conversion keeps track of IR modifications (requested by the
281 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
282 /// are directly applied to the IR as the rewriter API is used, some are applied
283 /// partially, and some are delayed until the `IRRewrite` objects are committed.
284 class IRRewrite {
285 public:
286  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
287  /// Enum values are ordered, so that they can be used in `classof`: first all
288  /// block rewrites, then all operation rewrites.
289  enum class Kind {
290  // Block rewrites
291  CreateBlock,
292  EraseBlock,
293  InlineBlock,
294  MoveBlock,
295  BlockTypeConversion,
296  ReplaceBlockArg,
297  // Operation rewrites
298  MoveOperation,
299  ModifyOperation,
300  ReplaceOperation,
301  CreateOperation,
302  UnresolvedMaterialization
303  };
304 
305  virtual ~IRRewrite() = default;
306 
307  /// Roll back the rewrite. Operations may be erased during rollback.
308  virtual void rollback() = 0;
309 
310  /// Commit the rewrite. At this point, it is certain that the dialect
311  /// conversion will succeed. All IR modifications, except for operation/block
312  /// erasure, must be performed through the given rewriter.
313  ///
314  /// Instead of erasing operations/blocks, they should merely be unlinked
315  /// commit phase and finally be erased during the cleanup phase. This is
316  /// because internal dialect conversion state (such as `mapping`) may still
317  /// be using them.
318  ///
319  /// Any IR modification that was already performed before the commit phase
320  /// (e.g., insertion of an op) must be communicated to the listener that may
321  /// be attached to the given rewriter.
322  virtual void commit(RewriterBase &rewriter) {}
323 
324  /// Cleanup operations/blocks. Cleanup is called after commit.
325  virtual void cleanup(RewriterBase &rewriter) {}
326 
327  Kind getKind() const { return kind; }
328 
329  static bool classof(const IRRewrite *rewrite) { return true; }
330 
331 protected:
332  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
333  : kind(kind), rewriterImpl(rewriterImpl) {}
334 
335  const ConversionConfig &getConfig() const;
336 
337  const Kind kind;
338  ConversionPatternRewriterImpl &rewriterImpl;
339 };
340 
341 /// A block rewrite.
342 class BlockRewrite : public IRRewrite {
343 public:
344  /// Return the block that this rewrite operates on.
345  Block *getBlock() const { return block; }
346 
347  static bool classof(const IRRewrite *rewrite) {
348  return rewrite->getKind() >= Kind::CreateBlock &&
349  rewrite->getKind() <= Kind::ReplaceBlockArg;
350  }
351 
352 protected:
353  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
354  Block *block)
355  : IRRewrite(kind, rewriterImpl), block(block) {}
356 
357  // The block that this rewrite operates on.
358  Block *block;
359 };
360 
361 /// Creation of a block. Block creations are immediately reflected in the IR.
362 /// There is no extra work to commit the rewrite. During rollback, the newly
363 /// created block is erased.
364 class CreateBlockRewrite : public BlockRewrite {
365 public:
366  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
367  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
368 
369  static bool classof(const IRRewrite *rewrite) {
370  return rewrite->getKind() == Kind::CreateBlock;
371  }
372 
373  void commit(RewriterBase &rewriter) override {
374  // The block was already created and inserted. Just inform the listener.
375  if (auto *listener = rewriter.getListener())
376  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
377  }
378 
379  void rollback() override {
380  // Unlink all of the operations within this block, they will be deleted
381  // separately.
382  auto &blockOps = block->getOperations();
383  while (!blockOps.empty())
384  blockOps.remove(blockOps.begin());
385  block->dropAllUses();
386  if (block->getParent())
387  block->erase();
388  else
389  delete block;
390  }
391 };
392 
393 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
394 /// blocks are immediately unlinked, but only erased during cleanup. This makes
395 /// it easier to rollback a block erasure: the block is simply inserted into its
396 /// original location.
397 class EraseBlockRewrite : public BlockRewrite {
398 public:
399  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
400  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
401  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
402 
403  static bool classof(const IRRewrite *rewrite) {
404  return rewrite->getKind() == Kind::EraseBlock;
405  }
406 
407  ~EraseBlockRewrite() override {
408  assert(!block &&
409  "rewrite was neither rolled back nor committed/cleaned up");
410  }
411 
412  void rollback() override {
413  // The block (owned by this rewrite) was not actually erased yet. It was
414  // just unlinked. Put it back into its original position.
415  assert(block && "expected block");
416  auto &blockList = region->getBlocks();
417  Region::iterator before = insertBeforeBlock
418  ? Region::iterator(insertBeforeBlock)
419  : blockList.end();
420  blockList.insert(before, block);
421  block = nullptr;
422  }
423 
424  void commit(RewriterBase &rewriter) override {
425  // Erase the block.
426  assert(block && "expected block");
427  assert(block->empty() && "expected empty block");
428 
429  // Notify the listener that the block is about to be erased.
430  if (auto *listener =
431  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
432  listener->notifyBlockErased(block);
433  }
434 
435  void cleanup(RewriterBase &rewriter) override {
436  // Erase the block.
437  block->dropAllDefinedValueUses();
438  delete block;
439  block = nullptr;
440  }
441 
442 private:
443  // The region in which this block was previously contained.
444  Region *region;
445 
446  // The original successor of this block before it was unlinked. "nullptr" if
447  // this block was the only block in the region.
448  Block *insertBeforeBlock;
449 };
450 
451 /// Inlining of a block. This rewrite is immediately reflected in the IR.
452 /// Note: This rewrite represents only the inlining of the operations. The
453 /// erasure of the inlined block is a separate rewrite.
454 class InlineBlockRewrite : public BlockRewrite {
455 public:
456  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
457  Block *sourceBlock, Block::iterator before)
458  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
459  sourceBlock(sourceBlock),
460  firstInlinedInst(sourceBlock->empty() ? nullptr
461  : &sourceBlock->front()),
462  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
463  // If a listener is attached to the dialect conversion, ops must be moved
464  // one-by-one. When they are moved in bulk, notifications cannot be sent
465  // because the ops that used to be in the source block at the time of the
466  // inlining (before the "commit" phase) are unknown at the time when
467  // notifications are sent (which is during the "commit" phase).
468  assert(!getConfig().listener &&
469  "InlineBlockRewrite not supported if listener is attached");
470  }
471 
472  static bool classof(const IRRewrite *rewrite) {
473  return rewrite->getKind() == Kind::InlineBlock;
474  }
475 
476  void rollback() override {
477  // Put the operations from the destination block (owned by the rewrite)
478  // back into the source block.
479  if (firstInlinedInst) {
480  assert(lastInlinedInst && "expected operation");
481  sourceBlock->getOperations().splice(sourceBlock->begin(),
482  block->getOperations(),
483  Block::iterator(firstInlinedInst),
484  ++Block::iterator(lastInlinedInst));
485  }
486  }
487 
488 private:
489  // The block that originally contained the operations.
490  Block *sourceBlock;
491 
492  // The first inlined operation.
493  Operation *firstInlinedInst;
494 
495  // The last inlined operation.
496  Operation *lastInlinedInst;
497 };
498 
499 /// Moving of a block. This rewrite is immediately reflected in the IR.
500 class MoveBlockRewrite : public BlockRewrite {
501 public:
502  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
503  Region *region, Block *insertBeforeBlock)
504  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
505  insertBeforeBlock(insertBeforeBlock) {}
506 
507  static bool classof(const IRRewrite *rewrite) {
508  return rewrite->getKind() == Kind::MoveBlock;
509  }
510 
511  void commit(RewriterBase &rewriter) override {
512  // The block was already moved. Just inform the listener.
513  if (auto *listener = rewriter.getListener()) {
514  // Note: `previousIt` cannot be passed because this is a delayed
515  // notification and iterators into past IR state cannot be represented.
516  listener->notifyBlockInserted(block, /*previous=*/region,
517  /*previousIt=*/{});
518  }
519  }
520 
521  void rollback() override {
522  // Move the block back to its original position.
523  Region::iterator before =
524  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
525  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
526  }
527 
528 private:
529  // The region in which this block was previously contained.
530  Region *region;
531 
532  // The original successor of this block before it was moved. "nullptr" if
533  // this block was the only block in the region.
534  Block *insertBeforeBlock;
535 };
536 
537 /// Block type conversion. This rewrite is partially reflected in the IR.
538 class BlockTypeConversionRewrite : public BlockRewrite {
539 public:
540  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
541  Block *origBlock, Block *newBlock)
542  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
543  newBlock(newBlock) {}
544 
545  static bool classof(const IRRewrite *rewrite) {
546  return rewrite->getKind() == Kind::BlockTypeConversion;
547  }
548 
549  Block *getOrigBlock() const { return block; }
550 
551  Block *getNewBlock() const { return newBlock; }
552 
553  void commit(RewriterBase &rewriter) override;
554 
555  void rollback() override;
556 
557 private:
558  /// The new block that was created as part of this signature conversion.
559  Block *newBlock;
560 };
561 
562 /// Replacing a block argument. This rewrite is not immediately reflected in the
563 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
564 /// until the rewrite is committed.
565 class ReplaceBlockArgRewrite : public BlockRewrite {
566 public:
567  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
568  Block *block, BlockArgument arg,
569  const TypeConverter *converter)
570  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
571  converter(converter) {}
572 
573  static bool classof(const IRRewrite *rewrite) {
574  return rewrite->getKind() == Kind::ReplaceBlockArg;
575  }
576 
577  void commit(RewriterBase &rewriter) override;
578 
579  void rollback() override;
580 
581 private:
582  BlockArgument arg;
583 
584  /// The current type converter when the block argument was replaced.
585  const TypeConverter *converter;
586 };
587 
588 /// An operation rewrite.
589 class OperationRewrite : public IRRewrite {
590 public:
591  /// Return the operation that this rewrite operates on.
592  Operation *getOperation() const { return op; }
593 
594  static bool classof(const IRRewrite *rewrite) {
595  return rewrite->getKind() >= Kind::MoveOperation &&
596  rewrite->getKind() <= Kind::UnresolvedMaterialization;
597  }
598 
599 protected:
600  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
601  Operation *op)
602  : IRRewrite(kind, rewriterImpl), op(op) {}
603 
604  // The operation that this rewrite operates on.
605  Operation *op;
606 };
607 
608 /// Moving of an operation. This rewrite is immediately reflected in the IR.
609 class MoveOperationRewrite : public OperationRewrite {
610 public:
611  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
612  Operation *op, Block *block, Operation *insertBeforeOp)
613  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
614  insertBeforeOp(insertBeforeOp) {}
615 
616  static bool classof(const IRRewrite *rewrite) {
617  return rewrite->getKind() == Kind::MoveOperation;
618  }
619 
620  void commit(RewriterBase &rewriter) override {
621  // The operation was already moved. Just inform the listener.
622  if (auto *listener = rewriter.getListener()) {
623  // Note: `previousIt` cannot be passed because this is a delayed
624  // notification and iterators into past IR state cannot be represented.
625  listener->notifyOperationInserted(
626  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
627  /*insertPt=*/{}));
628  }
629  }
630 
631  void rollback() override {
632  // Move the operation back to its original position.
633  Block::iterator before =
634  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
635  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
636  }
637 
638 private:
639  // The block in which this operation was previously contained.
640  Block *block;
641 
642  // The original successor of this operation before it was moved. "nullptr"
643  // if this operation was the only operation in the region.
644  Operation *insertBeforeOp;
645 };
646 
647 /// In-place modification of an op. This rewrite is immediately reflected in
648 /// the IR. The previous state of the operation is stored in this object.
649 class ModifyOperationRewrite : public OperationRewrite {
650 public:
651  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
652  Operation *op)
653  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
654  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
655  operands(op->operand_begin(), op->operand_end()),
656  successors(op->successor_begin(), op->successor_end()) {
657  if (OpaqueProperties prop = op->getPropertiesStorage()) {
658  // Make a copy of the properties.
659  propertiesStorage = operator new(op->getPropertiesStorageSize());
660  OpaqueProperties propCopy(propertiesStorage);
661  name.initOpProperties(propCopy, /*init=*/prop);
662  }
663  }
664 
665  static bool classof(const IRRewrite *rewrite) {
666  return rewrite->getKind() == Kind::ModifyOperation;
667  }
668 
669  ~ModifyOperationRewrite() override {
670  assert(!propertiesStorage &&
671  "rewrite was neither committed nor rolled back");
672  }
673 
674  void commit(RewriterBase &rewriter) override {
675  // Notify the listener that the operation was modified in-place.
676  if (auto *listener =
677  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
678  listener->notifyOperationModified(op);
679 
680  if (propertiesStorage) {
681  OpaqueProperties propCopy(propertiesStorage);
682  // Note: The operation may have been erased in the mean time, so
683  // OperationName must be stored in this object.
684  name.destroyOpProperties(propCopy);
685  operator delete(propertiesStorage);
686  propertiesStorage = nullptr;
687  }
688  }
689 
690  void rollback() override {
691  op->setLoc(loc);
692  op->setAttrs(attrs);
693  op->setOperands(operands);
694  for (const auto &it : llvm::enumerate(successors))
695  op->setSuccessor(it.value(), it.index());
696  if (propertiesStorage) {
697  OpaqueProperties propCopy(propertiesStorage);
698  op->copyProperties(propCopy);
699  name.destroyOpProperties(propCopy);
700  operator delete(propertiesStorage);
701  propertiesStorage = nullptr;
702  }
703  }
704 
705 private:
706  OperationName name;
707  LocationAttr loc;
708  DictionaryAttr attrs;
709  SmallVector<Value, 8> operands;
710  SmallVector<Block *, 2> successors;
711  void *propertiesStorage = nullptr;
712 };
713 
714 /// Replacing an operation. Erasing an operation is treated as a special case
715 /// with "null" replacements. This rewrite is not immediately reflected in the
716 /// IR. An internal IR mapping is updated, but values are not replaced and the
717 /// original op is not erased until the rewrite is committed.
718 class ReplaceOperationRewrite : public OperationRewrite {
719 public:
720  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
721  Operation *op, const TypeConverter *converter)
722  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
723  converter(converter) {}
724 
725  static bool classof(const IRRewrite *rewrite) {
726  return rewrite->getKind() == Kind::ReplaceOperation;
727  }
728 
729  void commit(RewriterBase &rewriter) override;
730 
731  void rollback() override;
732 
733  void cleanup(RewriterBase &rewriter) override;
734 
735 private:
736  /// An optional type converter that can be used to materialize conversions
737  /// between the new and old values if necessary.
738  const TypeConverter *converter;
739 };
740 
741 class CreateOperationRewrite : public OperationRewrite {
742 public:
743  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
744  Operation *op)
745  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
746 
747  static bool classof(const IRRewrite *rewrite) {
748  return rewrite->getKind() == Kind::CreateOperation;
749  }
750 
751  void commit(RewriterBase &rewriter) override {
752  // The operation was already created and inserted. Just inform the listener.
753  if (auto *listener = rewriter.getListener())
754  listener->notifyOperationInserted(op, /*previous=*/{});
755  }
756 
757  void rollback() override;
758 };
759 
/// Direction of a materialization.
enum MaterializationKind {
  /// Materializes a conversion from an illegal type to a legal one.
  Target,
  /// Materializes a conversion from a legal type back to an illegal one.
  Source
};
770 
771 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
772 /// op. Unresolved materializations are erased at the end of the dialect
773 /// conversion.
774 class UnresolvedMaterializationRewrite : public OperationRewrite {
775 public:
776  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
777  UnrealizedConversionCastOp op,
778  const TypeConverter *converter,
779  MaterializationKind kind, Type originalType,
780  ValueVector mappedValues);
781 
782  static bool classof(const IRRewrite *rewrite) {
783  return rewrite->getKind() == Kind::UnresolvedMaterialization;
784  }
785 
786  void rollback() override;
787 
788  UnrealizedConversionCastOp getOperation() const {
789  return cast<UnrealizedConversionCastOp>(op);
790  }
791 
792  /// Return the type converter of this materialization (which may be null).
793  const TypeConverter *getConverter() const {
794  return converterAndKind.getPointer();
795  }
796 
797  /// Return the kind of this materialization.
798  MaterializationKind getMaterializationKind() const {
799  return converterAndKind.getInt();
800  }
801 
802  /// Return the original type of the SSA value.
803  Type getOriginalType() const { return originalType; }
804 
805 private:
806  /// The corresponding type converter to use when resolving this
807  /// materialization, and the kind of this materialization.
808  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
809  converterAndKind;
810 
811  /// The original type of the SSA value. Only used for target
812  /// materializations.
813  Type originalType;
814 
815  /// The values in the conversion value mapping that are being replaced by the
816  /// results of this unresolved materialization.
817  ValueVector mappedValues;
818 };
819 } // namespace
820 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on the operation `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  auto matches = [&](auto &entry) {
    if (auto *typed = dyn_cast<RewriteTy>(entry.get()))
      return typed->getOperation() == op;
    return false;
  };
  return llvm::any_of(std::forward<R>(rewrites), matches);
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on the block `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  auto matches = [&](auto &entry) {
    if (auto *typed = dyn_cast<RewriteTy>(entry.get()))
      return typed->getBlock() == block;
    return false;
  };
  return llvm::any_of(std::forward<R>(rewrites), matches);
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
842 
843 //===----------------------------------------------------------------------===//
844 // ConversionPatternRewriterImpl
845 //===----------------------------------------------------------------------===//
846 namespace mlir {
847 namespace detail {
// NOTE(review): source lines 848-849 are absent from this listing. Judging by
// the member-initializer list below, the `override` notification hooks, and
// `new detail::ConversionPatternRewriterImpl(ctx, config)` used later in this
// file, they presumably hold the struct header and the first line of the
// constructor (taking `MLIRContext *ctx`) -- confirm against the real source.
850  const ConversionConfig &config)
851  : context(ctx), eraseRewriter(ctx), config(config) {}
852 
853  //===--------------------------------------------------------------------===//
854  // State Management
855  //===--------------------------------------------------------------------===//
856 
857  /// Return the current state of the rewriter.
858  RewriterState getCurrentState();
859 
860  /// Apply all requested operation rewrites. This method is invoked when the
861  /// conversion process succeeds.
862  void applyRewrites();
863 
864  /// Reset the state of the rewriter to a previously saved point. Optionally,
865  /// the name of the pattern that triggered the rollback can be specified for
866  /// debugging purposes.
867  void resetState(RewriterState state, StringRef patternName = "");
868 
869  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
870  /// failure.
  /// Every rewrite object receives `*this` as its first constructor argument.
871  template <typename RewriteTy, typename... Args>
872  void appendRewrite(Args &&...args) {
873  rewrites.push_back(
874  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
875  }
876 
877  /// Undo the rewrites (motions, splits) one by one in reverse order until
878  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
879  /// that triggered the rollback can be specified for debugging purposes.
880  void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
881 
882  /// Remap the given values to those with potentially different types. Returns
883  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
884  /// is the tag used when describing a value within a diagnostic, e.g.
885  /// "operand".
886  LogicalResult remapValues(StringRef valueDiagTag,
887  std::optional<Location> inputLoc,
888  PatternRewriter &rewriter, ValueRange values,
889  SmallVector<ValueVector> &remapped);
890 
891  /// Return "true" if the given operation is ignored, and does not need to be
892  /// converted.
893  bool isOpIgnored(Operation *op) const;
894 
895  /// Return "true" if the given operation was replaced or erased.
896  bool wasOpReplaced(Operation *op) const;
897 
898  //===--------------------------------------------------------------------===//
899  // Type Conversion
900  //===--------------------------------------------------------------------===//
901 
902  /// Convert the types of block arguments within the given region.
903  FailureOr<Block *>
904  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
905  const TypeConverter &converter,
906  TypeConverter::SignatureConversion *entryConversion);
907 
908  /// Apply the given signature conversion on the given block. The new block
909  /// containing the updated signature is returned. If no conversions were
910  /// necessary, e.g. if the block has no arguments, `block` is returned.
911  /// `converter` is used to generate any necessary cast operations that
912  /// translate between the original argument types and those specified in the
913  /// signature conversion.
914  Block *applySignatureConversion(
915  ConversionPatternRewriter &rewriter, Block *block,
916  const TypeConverter *converter,
917  TypeConverter::SignatureConversion &signatureConversion);
918 
919  //===--------------------------------------------------------------------===//
920  // Materializations
921  //===--------------------------------------------------------------------===//
922 
923  /// Build an unresolved materialization operation given a range of output
924  /// types and a list of input operands. The input types must differ from the
925  /// output types (the definition asserts that a materialization is needed).
926  ///
927  /// If a cast op was built, it can optionally be returned with the `castOp`
928  /// output argument.
929  ///
930  /// If `valuesToMap` is non-empty, those values are mapped to
931  /// the results of the unresolved materialization in the conversion value
932  /// mapping.
933  ValueRange buildUnresolvedMaterialization(
934  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
935  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
936  Type originalType, const TypeConverter *converter,
937  UnrealizedConversionCastOp *castOp = nullptr);
938 
939  /// Find a replacement value for the given SSA value in the conversion value
940  /// mapping. The replacement value must have the same type as the given SSA
941  /// value. If there is no replacement value with the correct type, find the
942  /// latest replacement value (regardless of the type) and build a source
943  /// materialization.
944  Value findOrBuildReplacementValue(Value value,
945  const TypeConverter *converter);
946 
947  //===--------------------------------------------------------------------===//
948  // Rewriter Notification Hooks
949  //===--------------------------------------------------------------------===//
950 
951  /// Notifies that an op was inserted.
952  void notifyOperationInserted(Operation *op,
953  OpBuilder::InsertPoint previous) override;
954 
955  /// Notifies that an op is about to be replaced with the given values.
956  void notifyOpReplaced(Operation *op,
957  SmallVector<SmallVector<Value>> &&newValues);
958 
959  /// Notifies that a block is about to be erased.
960  void notifyBlockIsBeingErased(Block *block);
961 
962  /// Notifies that a block was inserted.
963  void notifyBlockInserted(Block *block, Region *previous,
964  Region::iterator previousIt) override;
965 
966  /// Notifies that a block is being inlined into another block.
967  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
968  Block::iterator before);
969 
970  /// Notifies that a pattern match failed for the given reason.
971  void
972  notifyMatchFailure(Location loc,
973  function_ref<void(Diagnostic &)> reasonCallback) override;
974 
975  //===--------------------------------------------------------------------===//
976  // IR Erasure
977  //===--------------------------------------------------------------------===//
978 
979  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
980  /// operation or block is erased multiple times. This rewriter assumes that
981  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
  // NOTE(review): source lines 982 (class header), 984 (constructor head) and
  // 992 (inside `eraseOp`) are absent from this listing. The constructor passes
  // `this` as the listener, so the class presumably derives from both
  // `RewriterBase` and `RewriterBase::Listener`, and line 992 presumably calls
  // `RewriterBase::eraseOp(op)` -- confirm against the real source.
983  public:
985  : RewriterBase(context, /*listener=*/this) {}
986 
987  /// Erase the given op (unless it was already erased).
988  void eraseOp(Operation *op) override {
989  if (wasErased(op))
990  return;
991  op->dropAllUses();
993  }
994 
995  /// Erase the given block (unless it was already erased).
996  void eraseBlock(Block *block) override {
997  if (wasErased(block))
998  return;
999  assert(block->empty() && "expected empty block");
1000  block->dropAllDefinedValueUses();
1001  RewriterBase::eraseBlock(block);
1002  }
1003 
  /// Return "true" if `ptr` (an op or block) was recorded as erased.
1004  bool wasErased(void *ptr) const { return erased.contains(ptr); }
1005 
  // Listener hooks: record every erased op/block so that a second erase
  // request for the same object becomes a no-op.
1006  void notifyOperationErased(Operation *op) override { erased.insert(op); }
1007 
1008  void notifyBlockErased(Block *block) override { erased.insert(block); }
1009 
1010  private:
1011  /// Pointers to all erased operations and blocks.
1012  DenseSet<void *> erased;
1013  };
1014 
1015  //===--------------------------------------------------------------------===//
1016  // State
1017  //===--------------------------------------------------------------------===//
  // NOTE(review): most member declarations in this section (source lines 1020,
  // 1025, 1032, 1037, 1043, 1047-1048, 1056, 1059, 1065) are absent from this
  // listing. From their uses in this file they are presumably `context`,
  // `eraseRewriter`, `rewrites`, `ignoredOps`, `replacedOps`,
  // `unresolvedMaterializations`, `regionToConverter`, `config` and the
  // pending-root-update set -- confirm against the real source.
1018 
1019  /// MLIR context.
1021 
1022  /// A rewriter that keeps track of ops/block that were already erased and
1023  /// skips duplicate op/block erasures. This rewriter is used during the
1024  /// "cleanup" phase.
1026 
1027  // Mapping between replaced values that differ in type. This happens when
1028  // replacing a value with one of a different type.
1029  ConversionValueMapping mapping;
1030 
1031  /// Ordered list of block operations (creations, splits, motions).
  // NOTE(review): `appendRewrite` pushes every rewrite kind into this list
  // (operation rewrites and unresolved materializations too), not only block
  // operations.
1033 
1034  /// A set of operations that should no longer be considered for legalization.
1035  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1036  /// tracked separately.
1038 
1039  /// A set of operations that were replaced/erased. Such ops are not erased
1040  /// immediately but only when the dialect conversion succeeds. In the
1041  /// meantime, they should no longer be considered for legalization and any
1042  /// attempt to modify/access them is invalid rewriter API usage.
1044 
1045  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1046  /// to the corresponding rewrite objects.
1049 
1050  /// The current type converter, or nullptr if no type converter is currently
1051  /// active.
1052  const TypeConverter *currentTypeConverter = nullptr;
1053 
1054  /// A mapping of regions to type converters that should be used when
1055  /// converting the arguments of blocks within that region.
1057 
1058  /// Dialect conversion configuration.
1060 
1061 #ifndef NDEBUG
1062  /// A set of operations that have pending updates. This tracking isn't
1063  /// strictly necessary, and is thus only active during debug builds for extra
1064  /// verification.
1066 
1067  /// A logger used to emit diagnostics during the conversion process.
1068  llvm::ScopedPrinter logger{llvm::dbgs()};
1069 #endif
1070 };
1071 } // namespace detail
1072 } // namespace mlir
1073 
1074 const ConversionConfig &IRRewrite::getConfig() const {
1075  return rewriterImpl.config;
1076 }
1077 
1078 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1079  // Inform the listener about all IR modifications that have already taken
1080  // place: References to the original block have been replaced with the new
1081  // block.
1082  if (auto *listener =
1083  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1084  for (Operation *op : getNewBlock()->getUsers())
1085  listener->notifyOperationModified(op);
1086 }
1087 
1088 void BlockTypeConversionRewrite::rollback() {
1089  getNewBlock()->replaceAllUsesWith(getOrigBlock());
1090 }
1091 
1092 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1093  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1094  if (!repl)
1095  return;
1096 
1097  if (isa<BlockArgument>(repl)) {
1098  rewriter.replaceAllUsesWith(arg, repl);
1099  return;
1100  }
1101 
1102  // If the replacement value is an operation, we check to make sure that we
1103  // don't replace uses that are within the parent operation of the
1104  // replacement value.
1105  Operation *replOp = cast<OpResult>(repl).getOwner();
1106  Block *replBlock = replOp->getBlock();
1107  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1108  Operation *user = operand.getOwner();
1109  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1110  });
1111 }
1112 
1113 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
1114 
/// Commits an op replacement: computes (or materializes) a replacement value
/// for every result, notifies the listener, rewires the uses of each result
/// that has a replacement, and unlinks the op from its block. The op is only
/// destroyed later, in `cleanup`.
1115 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1116  auto *listener =
1117  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1118 
1119  // Compute replacement values.
1120  SmallVector<Value> replacements =
1121  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1122  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1123  });
1124 
1125  // Notify the listener that the operation is about to be replaced.
1126  if (listener)
1127  listener->notifyOperationReplaced(op, replacements);
1128 
1129  // Replace all uses with the new values.
 // A null entry (returned by `findOrBuildReplacementValue` for dead or dropped
 // values) leaves the uses of that result untouched.
1130  for (auto [result, newValue] :
1131  llvm::zip_equal(op->getResults(), replacements))
1132  if (newValue)
1133  rewriter.replaceAllUsesWith(result, newValue);
1134 
1135  // The original op will be erased, so remove it from the set of unlegalized
1136  // ops.
1137  if (getConfig().unlegalizedOps)
1138  getConfig().unlegalizedOps->erase(op);
1139 
1140  // Notify the listener that the operation (and its nested operations) was
1141  // erased.
 // NOTE(review): source line 1143 is absent from this listing; presumably it
 // opens an `op->walk...(` call whose callback is the lambda below -- confirm.
1142  if (listener) {
1144  [&](Operation *op) { listener->notifyOperationErased(op); });
1145  }
1146 
1147  // Do not erase the operation yet. It may still be referenced in `mapping`.
1148  // Just unlink it for now and erase it during cleanup.
1149  op->getBlock()->getOperations().remove(op);
1150 }
1151 
1152 void ReplaceOperationRewrite::rollback() {
1153  for (auto result : op->getResults())
1154  rewriterImpl.mapping.erase({result});
1155 }
1156 
1157 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1158  rewriter.eraseOp(op);
1159 }
1160 
1161 void CreateOperationRewrite::rollback() {
1162  for (Region &region : op->getRegions()) {
1163  while (!region.getBlocks().empty())
1164  region.getBlocks().remove(region.getBlocks().begin());
1165  }
1166  op->dropAllUses();
1167  op->erase();
1168 }
1169 
1170 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1171  ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1172  const TypeConverter *converter, MaterializationKind kind, Type originalType,
1173  ValueVector mappedValues)
1174  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1175  converterAndKind(converter, kind), originalType(originalType),
1176  mappedValues(std::move(mappedValues)) {
1177  assert((!originalType || kind == MaterializationKind::Target) &&
1178  "original type is valid only for target materializations");
1179  rewriterImpl.unresolvedMaterializations[op] = this;
1180 }
1181 
1182 void UnresolvedMaterializationRewrite::rollback() {
1183  if (!mappedValues.empty())
1184  rewriterImpl.mapping.erase(mappedValues);
1185  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1186  op->erase();
1187 }
1188 
// NOTE(review): the signature line (source line 1189) is absent from this
// listing; per the class declaration it is presumably
// `void ConversionPatternRewriterImpl::applyRewrites() {` -- confirm.
/// Runs when the conversion succeeds: commits every recorded rewrite in order
/// using an `IRRewriter` wired to `config.listener`, then runs `cleanup` on
/// each rewrite with `eraseRewriter`, which skips duplicate op/block erasures.
1190  // Commit all rewrites.
1191  IRRewriter rewriter(context, config.listener);
1192  // Note: New rewrites may be added during the "commit" phase and the
1193  // `rewrites` vector may reallocate.
 // (That is why this is an index loop re-reading `rewrites.size()` rather
 // than a range-for, which would hold iterators that reallocation invalidates.)
1194  for (size_t i = 0; i < rewrites.size(); ++i)
1195  rewrites[i]->commit(rewriter);
1196 
1197  // Clean up all rewrites.
1198  for (auto &rewrite : rewrites)
1199  rewrite->cleanup(eraseRewriter);
1200 }
1201 
1202 //===----------------------------------------------------------------------===//
1203 // State Management
1204 //===----------------------------------------------------------------------===//
1205 
// NOTE(review): the signature line (source line 1206) is absent from this
// listing; presumably
// `RewriterState ConversionPatternRewriterImpl::getCurrentState() {`.
/// Snapshots the current sizes of the rewrite log and of the ignored/replaced
/// op sets; `resetState` later truncates each of them back to these sizes.
1207  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1208 }
1209 
// NOTE(review): the first signature line (source line 1210) is absent from
// this listing; presumably
// `void ConversionPatternRewriterImpl::resetState(RewriterState state,`.
/// Rolls the rewriter back to a snapshot taken by `getCurrentState`: undoes
/// rewrites past `state.numRewrites` and pops ignored/replaced ops recorded
/// after the snapshot.
1211  StringRef patternName) {
1212  // Undo any rewrites.
1213  undoRewrites(state.numRewrites, patternName);
1214 
1215  // Pop all of the recorded ignored operations that are no longer valid.
1216  while (ignoredOps.size() != state.numIgnoredOperations)
1217  ignoredOps.pop_back();
1218 
 // Likewise drop ops marked as replaced after the snapshot was taken.
1219  while (replacedOps.size() != state.numReplacedOps)
1220  replacedOps.pop_back();
1221 }
1222 
/// Rolls back, newest first, every rewrite after the first `numRewritesToKeep`
/// entries and then truncates the rewrite log to that length.
1223 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1224  StringRef patternName) {
1225  for (auto &rewrite :
1226  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
 // NOTE(review): source line 1227, the first half of the condition below, is
 // absent from this listing. Given the error message, it presumably tests
 // whether pattern rollback is permitted by the config -- confirm.
1228  !isa<UnresolvedMaterializationRewrite>(rewrite)) {
1229  // Unresolved materializations can always be rolled back (erased).
 // (The comment above explains why they are exempt from the check, not what
 // this branch does: this branch aborts for every other rewrite kind.)
1230  llvm::report_fatal_error("pattern '" + patternName +
1231  "' rollback of IR modifications requested");
1232  }
1233  rewrite->rollback();
1234  }
1235  rewrites.resize(numRewritesToKeep);
1236 }
1237 
// NOTE(review): the first signature line (source line 1238) is absent from
// this listing; per the class declaration it is presumably
// `LogicalResult ConversionPatternRewriterImpl::remapValues(`.
/// For each value in `values`, appends to `remapped` the values it should be
/// replaced with: the latest mapped values when no type converter is active,
/// nothing when the type converts to zero types, a mapped value list already
/// of the legal types when one exists, or otherwise the results of a new
/// target materialization. Fails (and reports a match failure) when a type
/// cannot be converted.
1239  StringRef valueDiagTag, std::optional<Location> inputLoc,
1240  PatternRewriter &rewriter, ValueRange values,
1241  SmallVector<ValueVector> &remapped) {
1242  remapped.reserve(llvm::size(values));
1243 
1244  for (const auto &it : llvm::enumerate(values)) {
1245  Value operand = it.value();
1246  Type origType = operand.getType();
1247  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1248 
1249  if (!currentTypeConverter) {
1250  // The current pattern does not have a type converter. I.e., it does not
1251  // distinguish between legal and illegal types. For each operand, simply
1252  // pass through the most recently mapped values.
1253  remapped.push_back(mapping.lookupOrDefault(operand));
1254  continue;
1255  }
1256 
1257  // If there is no legal conversion, fail to match this pattern.
1258  SmallVector<Type, 1> legalTypes;
1259  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1260  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1261  diag << "unable to convert type for " << valueDiagTag << " #"
1262  << it.index() << ", type was " << origType;
1263  });
1264  return failure();
1265  }
1266  // If a type is converted to 0 types, there is nothing to do.
1267  if (legalTypes.empty()) {
1268  remapped.push_back({});
1269  continue;
1270  }
1271 
1272  ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
1273  if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1274  // Mapped values have the correct type or there is an existing
1275  // materialization. Or the operand is not mapped at all and has the
1276  // correct type.
1277  remapped.push_back(std::move(repl));
1278  continue;
1279  }
1280 
1281  // Create a materialization for the most recently mapped values.
 // NOTE(review): source line 1283 is absent from this listing; presumably it
 // declares `castValues` as the result of `buildUnresolvedMaterialization(`,
 // whose arguments follow -- confirm.
1282  repl = mapping.lookupOrDefault(operand);
1284  MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1285  /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1286  /*originalType=*/origType, currentTypeConverter);
1287  remapped.push_back(castValues);
1288  }
1289  return success();
1290 }
1291 
// NOTE(review): the signature line (source line 1292) is absent from this
// listing; presumably
// `bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {`.
/// Returns "true" if `op` is in either the replaced-op set or the ignored-op
/// set.
1293  // Check to see if this operation is ignored or was replaced.
1294  return replacedOps.count(op) || ignoredOps.count(op);
1295 }
1296 
// NOTE(review): the signature line (source line 1297) is absent from this
// listing; presumably
// `bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {`.
/// Returns "true" if `op` is in the replaced-op set. That set is filled by
/// `notifyOpReplaced`, which marks the op and all ops nested in it.
1298  // Check to see if this operation was replaced.
1299  return replacedOps.count(op);
1300 }
1301 
1302 //===----------------------------------------------------------------------===//
1303 // Type Conversion
1304 //===----------------------------------------------------------------------===//
1305 
// NOTE(review): the first signature line (source line 1306) is absent from
// this listing; per the class declaration it is presumably
// `FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(`.
/// Records `converter` for `region`, then converts the signatures of all
/// non-entry blocks followed by the entry block. Returns nullptr for an empty
/// region, failure if any block signature cannot be converted, and otherwise
/// the (possibly new) entry block.
1307  ConversionPatternRewriter &rewriter, Region *region,
1308  const TypeConverter &converter,
1309  TypeConverter::SignatureConversion *entryConversion) {
1310  regionToConverter[region] = &converter;
1311  if (region->empty())
1312  return nullptr;
1313 
1314  // Convert the arguments of each non-entry block within the region.
 // (early-inc iteration: converting a block inserts a new block next to it and
 // erases the old one, so the iterator must be advanced first.)
1315  for (Block &block :
1316  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1317  // Compute the signature for the block with the provided converter.
1318  std::optional<TypeConverter::SignatureConversion> conversion =
1319  converter.convertBlockSignature(&block);
1320  if (!conversion)
1321  return failure();
1322  // Convert the block with the computed signature.
1323  applySignatureConversion(rewriter, &block, &converter, *conversion);
1324  }
1325 
1326  // Convert the entry block. If an entry signature conversion was provided,
1327  // use that one. Otherwise, compute the signature with the type converter.
1328  if (entryConversion)
1329  return applySignatureConversion(rewriter, &region->front(), &converter,
1330  *entryConversion);
1331  std::optional<TypeConverter::SignatureConversion> conversion =
1332  converter.convertBlockSignature(&region->front());
1333  if (!conversion)
1334  return failure();
1335  return applySignatureConversion(rewriter, &region->front(), &converter,
1336  *conversion);
1337 }
1338 
// NOTE(review): the first signature line (source line 1339) is absent from
// this listing; per the class declaration it is presumably
// `Block *ConversionPatternRewriterImpl::applySignatureConversion(`.
/// Replaces `block` with a new block whose argument types are
/// `signatureConversion.getConvertedTypes()`. The function:
///  - moves all ops from `block` into the new block,
///  - redirects uses of `block` to the new block,
///  - maps each original argument to its replacement (new arguments, provided
///    replacement values, or a source materialization),
///  - unlinks the old block.
/// Returns `block` unchanged if the argument types already match, otherwise
/// the new block.
1340  ConversionPatternRewriter &rewriter, Block *block,
1341  const TypeConverter *converter,
1342  TypeConverter::SignatureConversion &signatureConversion) {
1343 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1344  // A block cannot be converted multiple times.
1345  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1346  llvm::report_fatal_error("block was already converted");
1347 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1348 
1349  OpBuilder::InsertionGuard g(rewriter);
1350 
1351  // If no arguments are being changed or added, there is nothing to do.
1352  unsigned origArgCount = block->getNumArguments();
1353  auto convertedTypes = signatureConversion.getConvertedTypes();
1354  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1355  return block;
1356 
1357  // Compute the locations of all block arguments in the new block.
 // New arguments default to an unknown location; arguments produced from an
 // original argument inherit that argument's location.
1358  SmallVector<Location> newLocs(convertedTypes.size(),
1359  rewriter.getUnknownLoc());
1360  for (unsigned i = 0; i < origArgCount; ++i) {
1361  auto inputMap = signatureConversion.getInputMapping(i);
1362  if (!inputMap || inputMap->replacedWithValues())
1363  continue;
1364  Location origLoc = block->getArgument(i).getLoc();
1365  for (unsigned j = 0; j < inputMap->size; ++j)
1366  newLocs[inputMap->inputNo + j] = origLoc;
1367  }
1368 
1369  // Insert a new block with the converted block argument types and move all ops
1370  // from the old block to the new block.
1371  Block *newBlock =
1372  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1373  convertedTypes, newLocs);
1374 
1375  // If a listener is attached to the dialect conversion, ops cannot be moved
1376  // to the destination block in bulk ("fast path"). This is because at the time
1377  // the notifications are sent, it is unknown which ops were moved. Instead,
1378  // ops should be moved one-by-one ("slow path"), so that a separate
1379  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1380  // a bit more efficient, so we try to do that when possible.
1381  bool fastPath = !config.listener;
1382  if (fastPath) {
1383  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1384  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1385  } else {
1386  while (!block->empty())
1387  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1388  }
1389 
1390  // Replace all uses of the old block with the new block.
1391  block->replaceAllUsesWith(newBlock);
1392 
1393  for (unsigned i = 0; i != origArgCount; ++i) {
1394  BlockArgument origArg = block->getArgument(i);
1395  Type origArgType = origArg.getType();
1396 
1397  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1398  signatureConversion.getInputMapping(i);
1399  if (!inputMap) {
1400  // This block argument was dropped and no replacement value was provided.
1401  // Materialize a replacement value "out of thin air".
 // NOTE(review): source line 1402 is absent from this listing; presumably it
 // opens the `buildUnresolvedMaterialization(` call whose arguments follow.
1403  MaterializationKind::Source,
1404  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1405  /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
1406  /*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
1407  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1408  continue;
1409  }
1410 
1411  if (inputMap->replacedWithValues()) {
1412  // This block argument was dropped and replacement values were provided.
1413  assert(inputMap->size == 0 &&
1414  "invalid to provide a replacement value when the argument isn't "
1415  "dropped");
1416  mapping.map(origArg, inputMap->replacementValues);
1417  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1418  continue;
1419  }
1420 
1421  // This is a 1->1+ mapping.
1422  auto replArgs =
1423  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1424  ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
1425  mapping.map(origArg, std::move(replArgVals));
1426  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1427  }
1428 
1429  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1430 
1431  // Erase the old block. (It is just unlinked for now and will be erased during
1432  // cleanup.)
1433  rewriter.eraseBlock(block);
1434 
1435  return newBlock;
1436 }
1437 
1438 //===----------------------------------------------------------------------===//
1439 // Materializations
1440 //===----------------------------------------------------------------------===//
1441 
1442 /// Build an unresolved materialization operation given an output type and set
1443 /// of input operands.
/// Creates an `UnrealizedConversionCastOp` at `ip`, maps `valuesToMap` (if
/// non-empty) to its results, optionally returns the op via `castOp`, records
/// an `UnresolvedMaterializationRewrite` for it, and returns its results.
// NOTE(review): the first signature line (source line 1444) is absent from
// this listing; presumably
// `ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(`.
1445  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1446  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1447  Type originalType, const TypeConverter *converter,
1448  UnrealizedConversionCastOp *castOp) {
1449  assert((!originalType || kind == MaterializationKind::Target) &&
1450  "original type is valid only for target materializations");
1451  assert(TypeRange(inputs) != outputTypes &&
1452  "materialization is not necessary");
1453 
1454  // Create an unresolved materialization. We use a new OpBuilder to avoid
1455  // tracking the materialization like we do for other operations.
 // NOTE(review): `outputTypes.front()` assumes at least one output type; an
 // empty `outputTypes` would be invalid here -- confirm all callers guarantee
 // this.
1456  OpBuilder builder(outputTypes.front().getContext());
1457  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1458  auto convertOp =
1459  builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
1460  if (!valuesToMap.empty())
1461  mapping.map(valuesToMap, convertOp.getResults());
1462  if (castOp)
1463  *castOp = convertOp;
1464  appendRewrite<UnresolvedMaterializationRewrite>(
1465  convertOp, converter, kind, originalType, std::move(valuesToMap));
1466  return convertOp.getResults();
1467 }
1468 
// NOTE(review): the first signature line (source line 1469) is absent from
// this listing; presumably
// `Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(`.
/// Returns a replacement for `value` of the same type. The lookup order is:
///  1. a mapped value of that type;
///  2. null if the value looks dead or has no mapping at all;
///  3. otherwise, a new source materialization from the latest mapped values.
1470  Value value, const TypeConverter *converter) {
1471  // Try to find a replacement value with the same type in the conversion value
1472  // mapping. This includes cached materializations. We try to reuse those
1473  // instead of generating duplicate IR.
1474  ValueVector repl = mapping.lookupOrNull(value, value.getType());
1475  if (!repl.empty())
1476  return repl.front();
1477 
1478  // Check if the value is dead. No replacement value is needed in that case.
1479  // This is an approximate check that may have false negatives but does not
1480  // require computing and traversing an inverse mapping. (We may end up
1481  // building source materializations that are never used and that fold away.)
1482  if (llvm::all_of(value.getUsers(),
1483  [&](Operation *op) { return replacedOps.contains(op); }) &&
1484  !mapping.isMappedTo(value))
1485  return Value();
1486 
1487  // No replacement value was found. Get the latest replacement value
1488  // (regardless of the type) and build a source materialization to the
1489  // original type.
1490  repl = mapping.lookupOrNull(value);
1491  if (repl.empty()) {
1492  // No replacement value is registered in the mapping. This means that the
1493  // value is dropped and no longer needed. (If the value were still needed,
1494  // a source materialization producing a replacement value "out of thin air"
1495  // would have already been created during `replaceOp` or
1496  // `applySignatureConversion`.)
1497  return Value();
1498  }
1499 
1500  // Note: `computeInsertPoint` computes the "earliest" insertion point at
1501  // which all values in `repl` are defined. It is important to emit the
1502  // materialization at that location because the same materialization may be
1503  // reused in a different context. (That's because materializations are cached
1504  // in the conversion value mapping.) The insertion point of the
1505  // materialization must be valid for all future users that may be created
1506  // later in the conversion process.
1507  Value castValue =
1508  buildUnresolvedMaterialization(MaterializationKind::Source,
1509  computeInsertPoint(repl), value.getLoc(),
1510  /*valuesToMap=*/repl, /*inputs=*/repl,
1511  /*outputTypes=*/value.getType(),
1512  /*originalType=*/Type(), converter)
1513  .front();
1514  return castValue;
1515 }
1516 
1517 //===----------------------------------------------------------------------===//
1518 // Rewriter Notification Hooks
1519 //===----------------------------------------------------------------------===//
1520 
// NOTE(review): the first signature line (source line 1521) is absent from
// this listing; presumably
// `void ConversionPatternRewriterImpl::notifyOperationInserted(`.
/// Listener hook. It records one of two rewrites:
///  - a `CreateOperationRewrite` when `previous` is unset (a newly created op),
///  - a `MoveOperationRewrite` remembering the previous block and the op it
///    preceded (nullptr when it was at the block end).
1522  Operation *op, OpBuilder::InsertPoint previous) {
1523  LLVM_DEBUG({
1524  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1525  << ")\n";
1526  });
1527  assert(!wasOpReplaced(op->getParentOp()) &&
1528  "attempting to insert into a block within a replaced/erased op");
1529 
1530  if (!previous.isSet()) {
1531  // This is a newly created op.
1532  appendRewrite<CreateOperationRewrite>(op);
1533  return;
1534  }
1535  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1536  ? nullptr
1537  : &*previous.getPoint();
1538  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1539 }
1540 
// NOTE(review): source lines 1541 (signature head), 1564 (start of a
// `buildUnresolvedMaterialization(` call) and 1568 (its final argument,
// presumably `currentTypeConverter);`) are absent from this listing -- confirm.
/// Records the replacement of `op`. Each result is mapped to its replacement
/// values; dropped results get a source materialization "out of thin air",
/// except when `op` is itself an unresolved materialization. A
/// `ReplaceOperationRewrite` is then appended, and `op` plus every nested op
/// is added to `replacedOps`.
// NOTE(review): the second assert tests `ignoredOps`, but its message says
// "already replaced" -- confirm whether `replacedOps` was intended.
1542  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1543  assert(newValues.size() == op->getNumResults());
1544  assert(!ignoredOps.contains(op) && "operation was already replaced");
1545 
1546  // Check if replaced op is an unresolved materialization, i.e., an
1547  // unrealized_conversion_cast op that was created by the conversion driver.
1548  bool isUnresolvedMaterialization = false;
1549  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1550  if (unresolvedMaterializations.contains(castOp))
1551  isUnresolvedMaterialization = true;
1552 
1553  // Create mappings for each of the new result values.
1554  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
1555  if (repl.empty()) {
1556  // This result was dropped and no replacement value was provided.
1557  if (isUnresolvedMaterialization) {
1558  // Do not create another materializations if we are erasing a
1559  // materialization.
1560  continue;
1561  }
1562 
1563  // Materialize a replacement value "out of thin air".
1565  MaterializationKind::Source, computeInsertPoint(result),
1566  result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
1567  /*outputTypes=*/result.getType(), /*originalType=*/Type(),
1569  continue;
1570  } else {
1571  // Make sure that the user does not mess with unresolved materializations
1572  // that were inserted by the conversion driver. We keep track of these
1573  // ops in internal data structures. Erasing them must be allowed because
1574  // this can happen when the user is erasing an entire block (including
1575  // its body). But replacing them with another value should be forbidden
1576  // to avoid problems with the `mapping`.
1577  assert(!isUnresolvedMaterialization &&
1578  "attempting to replace an unresolved materialization");
1579  }
1580 
1581  // Remap result to replacement value.
 // NOTE(review): the check below can never be true; every empty `repl` already
 // hit a `continue` in the branch above.
1582  if (repl.empty())
1583  continue;
1584  mapping.map(static_cast<Value>(result), std::move(repl));
1585  }
1586 
1587  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1588  // Mark this operation and all nested ops as replaced.
1589  op->walk([&](Operation *op) { replacedOps.insert(op); });
1590 }
1591 
// NOTE(review): the signature line (source line 1592) is absent from this
// listing; presumably
// `void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {`.
/// Records an `EraseBlockRewrite` for `block` so the erasure can be committed
/// or rolled back later.
1593  appendRewrite<EraseBlockRewrite>(block);
1594 }
1595 
// NOTE(review): the first signature line (source line 1596) is absent from
// this listing; presumably
// `void ConversionPatternRewriterImpl::notifyBlockInserted(`.
/// Listener hook. It records one of two rewrites:
///  - a `CreateBlockRewrite` when `previous` is null (a newly created block),
///  - a `MoveBlockRewrite` remembering the previous region and the block it
///    preceded (nullptr when it was at the region end).
1597  Block *block, Region *previous, Region::iterator previousIt) {
1598  assert(!wasOpReplaced(block->getParentOp()) &&
1599  "attempting to insert into a region within a replaced/erased op");
 // NOTE(review): the detached-region log line below ends with a stray `'` and
 // lacks the trailing "\n" used by the other branch.
1600  LLVM_DEBUG(
1601  {
1602  Operation *parent = block->getParentOp();
1603  if (parent) {
1604  logger.startLine() << "** Insert Block into : '" << parent->getName()
1605  << "'(" << parent << ")\n";
1606  } else {
1607  logger.startLine()
1608  << "** Insert Block into detached Region (nullptr parent op)'";
1609  }
1610  });
1611 
1612  if (!previous) {
1613  // This is a newly created block.
1614  appendRewrite<CreateBlockRewrite>(block);
1615  return;
1616  }
1617  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1618  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1619 }
1620 
// NOTE(review): the first signature line (source line 1621) is absent from
// this listing; presumably
// `void ConversionPatternRewriterImpl::notifyBlockBeingInlined(`.
/// Records an `InlineBlockRewrite` for inlining `srcBlock` into `block` before
/// `before`.
1622  Block *block, Block *srcBlock, Block::iterator before) {
1623  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1624 }
1625 
// NOTE(review): source lines 1626 (signature head), 1629 (presumably the
// declaration of the `diag` Diagnostic at `loc`) and 1633 (presumably the call
// to `config.notifyCallback`) are absent from this listing -- confirm.
/// Logs the failure reason produced by `reasonCallback`.
// NOTE(review): the entire body, including the `config.notifyCallback` check,
// sits inside LLVM_DEBUG, so the callback only fires in builds with debug
// output enabled -- confirm this is intended.
1627  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1628  LLVM_DEBUG({
1630  reasonCallback(diag);
1631  logger.startLine() << "** Failure : " << diag.str() << "\n";
1632  if (config.notifyCallback)
1634  });
1635 }
1636 
1637 //===----------------------------------------------------------------------===//
1638 // ConversionPatternRewriter
1639 //===----------------------------------------------------------------------===//
1640 
1641 ConversionPatternRewriter::ConversionPatternRewriter(
1642  MLIRContext *ctx, const ConversionConfig &config)
1643  : PatternRewriter(ctx),
1644  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1645  setListener(impl.get());
1646 }
1647 
1649 
// NOTE(review): the signature line (source line 1650) is absent from this
// listing; presumably
// `void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {`.
/// Replaces `op` with the results of `newOp` by forwarding to the
/// `ValueRange` overload.
1651  assert(op && newOp && "expected non-null op");
1652  replaceOp(op, newOp->getResults());
1653 }
1654 
// NOTE(review): source lines 1655 (signature, presumably
// `void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {`)
// and 1662 (presumably the declaration of `newVals`) are absent -- confirm.
/// Replaces each result of `op` with the corresponding value in `newValues`.
/// A null value becomes an empty replacement list, i.e. the result is dropped.
1656  assert(op->getNumResults() == newValues.size() &&
1657  "incorrect # of replacement values");
1658  LLVM_DEBUG({
1659  impl->logger.startLine()
1660  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1661  });
1663  llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
1664  return v ? SmallVector<Value>{v} : SmallVector<Value>();
1665  });
1666  impl->notifyOpReplaced(op, std::move(newVals));
1667 }
1668 
1670  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1671  assert(op->getNumResults() == newValues.size() &&
1672  "incorrect # of replacement values");
1673  LLVM_DEBUG({
1674  impl->logger.startLine()
1675  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1676  });
1677  impl->notifyOpReplaced(op, std::move(newValues));
1678 }
1679 
1681  LLVM_DEBUG({
1682  impl->logger.startLine()
1683  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1684  });
1685  SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
1686  impl->notifyOpReplaced(op, std::move(nullRepls));
1687 }
1688 
1690  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1691  "attempting to erase a block within a replaced/erased op");
1692 
1693  // Mark all ops for erasure.
1694  for (Operation &op : *block)
1695  eraseOp(&op);
1696 
1697  // Unlink the block from its parent region. The block is kept in the rewrite
1698  // object and will be actually destroyed when rewrites are applied. This
1699  // allows us to keep the operations in the block live and undo the removal by
1700  // re-inserting the block.
1701  impl->notifyBlockIsBeingErased(block);
1702  block->getParent()->getBlocks().remove(block);
1703 }
1704 
1706  Block *block, TypeConverter::SignatureConversion &conversion,
1707  const TypeConverter *converter) {
1708  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1709  "attempting to apply a signature conversion to a block within a "
1710  "replaced/erased op");
1711  return impl->applySignatureConversion(*this, block, converter, conversion);
1712 }
1713 
1715  Region *region, const TypeConverter &converter,
1716  TypeConverter::SignatureConversion *entryConversion) {
1717  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1718  "attempting to apply a signature conversion to a block within a "
1719  "replaced/erased op");
1720  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1721 }
1722 
1724  Value to) {
1725  LLVM_DEBUG({
1726  impl->logger.startLine() << "** Replace Argument : '" << from << "'";
1727  if (Operation *parentOp = from.getOwner()->getParentOp()) {
1728  impl->logger.getOStream() << " (in region of '" << parentOp->getName()
1729  << "' (" << parentOp << ")\n";
1730  } else {
1731  impl->logger.getOStream() << " (unlinked block)\n";
1732  }
1733  });
1734  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1735  impl->currentTypeConverter);
1736  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1737 }
1738 
1740  SmallVector<ValueVector> remappedValues;
1741  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1742  remappedValues)))
1743  return nullptr;
1744  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
1745  return remappedValues.front().front();
1746 }
1747 
1748 LogicalResult
1750  SmallVectorImpl<Value> &results) {
1751  if (keys.empty())
1752  return success();
1753  SmallVector<ValueVector> remapped;
1754  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1755  remapped)))
1756  return failure();
1757  for (const auto &values : remapped) {
1758  assert(values.size() == 1 && "1:N conversion not supported");
1759  results.push_back(values.front());
1760  }
1761  return success();
1762 }
1763 
1765  Block::iterator before,
1766  ValueRange argValues) {
1767 #ifndef NDEBUG
1768  assert(argValues.size() == source->getNumArguments() &&
1769  "incorrect # of argument replacement values");
1770  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1771  "attempting to inline a block from a replaced/erased op");
1772  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1773  "attempting to inline a block into a replaced/erased op");
1774  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1775  // The source block will be deleted, so it should not have any users (i.e.,
1776  // there should be no predecessors).
1777  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1778  "expected 'source' to have no predecessors");
1779 #endif // NDEBUG
1780 
1781  // If a listener is attached to the dialect conversion, ops cannot be moved
1782  // to the destination block in bulk ("fast path"). This is because at the time
1783  // the notifications are sent, it is unknown which ops were moved. Instead,
1784  // ops should be moved one-by-one ("slow path"), so that a separate
1785  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1786  // a bit more efficient, so we try to do that when possible.
1787  bool fastPath = !impl->config.listener;
1788 
1789  if (fastPath)
1790  impl->notifyBlockBeingInlined(dest, source, before);
1791 
1792  // Replace all uses of block arguments.
1793  for (auto it : llvm::zip(source->getArguments(), argValues))
1794  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1795 
1796  if (fastPath) {
1797  // Move all ops at once.
1798  dest->getOperations().splice(before, source->getOperations());
1799  } else {
1800  // Move op by op.
1801  while (!source->empty())
1802  moveOpBefore(&source->front(), dest, before);
1803  }
1804 
1805  // Erase the source block.
1806  eraseBlock(source);
1807 }
1808 
1810  assert(!impl->wasOpReplaced(op) &&
1811  "attempting to modify a replaced/erased op");
1812 #ifndef NDEBUG
1813  impl->pendingRootUpdates.insert(op);
1814 #endif
1815  impl->appendRewrite<ModifyOperationRewrite>(op);
1816 }
1817 
1819  assert(!impl->wasOpReplaced(op) &&
1820  "attempting to modify a replaced/erased op");
1822  // There is nothing to do here, we only need to track the operation at the
1823  // start of the update.
1824 #ifndef NDEBUG
1825  assert(impl->pendingRootUpdates.erase(op) &&
1826  "operation did not have a pending in-place update");
1827 #endif
1828 }
1829 
1831 #ifndef NDEBUG
1832  assert(impl->pendingRootUpdates.erase(op) &&
1833  "operation did not have a pending in-place update");
1834 #endif
1835  // Erase the last update for this operation.
1836  auto it = llvm::find_if(
1837  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1838  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1839  return modifyRewrite && modifyRewrite->getOperation() == op;
1840  });
1841  assert(it != impl->rewrites.rend() && "no root update started on op");
1842  (*it)->rollback();
1843  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1844  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1845 }
1846 
1848  return *impl;
1849 }
1850 
1851 //===----------------------------------------------------------------------===//
1852 // ConversionPattern
1853 //===----------------------------------------------------------------------===//
1854 
1856  ArrayRef<ValueRange> operands) const {
1857  SmallVector<Value> oneToOneOperands;
1858  oneToOneOperands.reserve(operands.size());
1859  for (ValueRange operand : operands) {
1860  if (operand.size() != 1)
1861  llvm::report_fatal_error("pattern '" + getDebugName() +
1862  "' does not support 1:N conversion");
1863  oneToOneOperands.push_back(operand.front());
1864  }
1865  return oneToOneOperands;
1866 }
1867 
1868 LogicalResult
1870  PatternRewriter &rewriter) const {
1871  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1872  auto &rewriterImpl = dialectRewriter.getImpl();
1873 
1874  // Track the current conversion pattern type converter in the rewriter.
1875  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1876  getTypeConverter());
1877 
1878  // Remap the operands of the operation.
1879  SmallVector<ValueVector> remapped;
1880  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1881  op->getOperands(), remapped))) {
1882  return failure();
1883  }
1884  SmallVector<ValueRange> remappedAsRange =
1885  llvm::to_vector_of<ValueRange>(remapped);
1886  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
1887 }
1888 
1889 //===----------------------------------------------------------------------===//
1890 // OperationLegalizer
1891 //===----------------------------------------------------------------------===//
1892 
1893 namespace {
1894 /// A set of rewrite patterns that can be used to legalize a given operation.
1895 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1896 
1897 /// This class defines a recursive operation legalizer.
1898 class OperationLegalizer {
1899 public:
1900  using LegalizationAction = ConversionTarget::LegalizationAction;
1901 
1902  OperationLegalizer(const ConversionTarget &targetInfo,
1904  const ConversionConfig &config);
1905 
1906  /// Returns true if the given operation is known to be illegal on the target.
1907  bool isIllegal(Operation *op) const;
1908 
1909  /// Attempt to legalize the given operation. Returns success if the operation
1910  /// was legalized, failure otherwise.
1911  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1912 
1913  /// Returns the conversion target in use by the legalizer.
1914  const ConversionTarget &getTarget() { return target; }
1915 
1916 private:
1917  /// Attempt to legalize the given operation by folding it.
1918  LogicalResult legalizeWithFold(Operation *op,
1919  ConversionPatternRewriter &rewriter);
1920 
1921  /// Attempt to legalize the given operation by applying a pattern. Returns
1922  /// success if the operation was legalized, failure otherwise.
1923  LogicalResult legalizeWithPattern(Operation *op,
1924  ConversionPatternRewriter &rewriter);
1925 
1926  /// Return true if the given pattern may be applied to the given operation,
1927  /// false otherwise.
1928  bool canApplyPattern(Operation *op, const Pattern &pattern,
1929  ConversionPatternRewriter &rewriter);
1930 
1931  /// Legalize the resultant IR after successfully applying the given pattern.
1932  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1933  ConversionPatternRewriter &rewriter,
1934  RewriterState &curState);
1935 
1936  /// Legalizes the actions registered during the execution of a pattern.
1937  LogicalResult
1938  legalizePatternBlockRewrites(Operation *op,
1939  ConversionPatternRewriter &rewriter,
1941  RewriterState &state, RewriterState &newState);
1942  LogicalResult legalizePatternCreatedOperations(
1944  RewriterState &state, RewriterState &newState);
1945  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1947  RewriterState &state,
1948  RewriterState &newState);
1949 
1950  //===--------------------------------------------------------------------===//
1951  // Cost Model
1952  //===--------------------------------------------------------------------===//
1953 
1954  /// Build an optimistic legalization graph given the provided patterns. This
1955  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1956  /// patterns for operations that are not directly legal, but may be
1957  /// transitively legal for the current target given the provided patterns.
1958  void buildLegalizationGraph(
1959  LegalizationPatterns &anyOpLegalizerPatterns,
1961 
1962  /// Compute the benefit of each node within the computed legalization graph.
1963  /// This orders the patterns within 'legalizerPatterns' based upon two
1964  /// criteria:
1965  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1966  /// represent the more direct mapping to the target.
1967  /// 2) When comparing patterns with the same legalization depth, prefer the
1968  /// pattern with the highest PatternBenefit. This allows for users to
1969  /// prefer specific legalizations over others.
1970  void computeLegalizationGraphBenefit(
1971  LegalizationPatterns &anyOpLegalizerPatterns,
1973 
1974  /// Compute the legalization depth when legalizing an operation of the given
1975  /// type.
1976  unsigned computeOpLegalizationDepth(
1977  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1979 
1980  /// Apply the conversion cost model to the given set of patterns, and return
1981  /// the smallest legalization depth of any of the patterns. See
1982  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1983  unsigned applyCostModelToPatterns(
1984  LegalizationPatterns &patterns,
1985  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1987 
1988  /// The current set of patterns that have been applied.
1989  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1990 
1991  /// The legalization information provided by the target.
1992  const ConversionTarget &target;
1993 
1994  /// The pattern applicator to use for conversions.
1995  PatternApplicator applicator;
1996 
1997  /// Dialect conversion configuration.
1998  const ConversionConfig &config;
1999 };
2000 } // namespace
2001 
2002 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
2004  const ConversionConfig &config)
2005  : target(targetInfo), applicator(patterns), config(config) {
2006  // The set of patterns that can be applied to illegal operations to transform
2007  // them into legal ones.
2009  LegalizationPatterns anyOpLegalizerPatterns;
2010 
2011  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2012  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2013 }
2014 
2015 bool OperationLegalizer::isIllegal(Operation *op) const {
2016  return target.isIllegal(op);
2017 }
2018 
2019 LogicalResult
2020 OperationLegalizer::legalize(Operation *op,
2021  ConversionPatternRewriter &rewriter) {
2022 #ifndef NDEBUG
2023  const char *logLineComment =
2024  "//===-------------------------------------------===//\n";
2025 
2026  auto &logger = rewriter.getImpl().logger;
2027 #endif
2028  LLVM_DEBUG({
2029  logger.getOStream() << "\n";
2030  logger.startLine() << logLineComment;
2031  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
2032  << op << ") {\n";
2033  logger.indent();
2034 
2035  // If the operation has no regions, just print it here.
2036  if (op->getNumRegions() == 0) {
2037  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2038  logger.getOStream() << "\n\n";
2039  }
2040  });
2041 
2042  // Check if this operation is legal on the target.
2043  if (auto legalityInfo = target.isLegal(op)) {
2044  LLVM_DEBUG({
2045  logSuccess(
2046  logger, "operation marked legal by the target{0}",
2047  legalityInfo->isRecursivelyLegal
2048  ? "; NOTE: operation is recursively legal; skipping internals"
2049  : "");
2050  logger.startLine() << logLineComment;
2051  });
2052 
2053  // If this operation is recursively legal, mark its children as ignored so
2054  // that we don't consider them for legalization.
2055  if (legalityInfo->isRecursivelyLegal) {
2056  op->walk([&](Operation *nested) {
2057  if (op != nested)
2058  rewriter.getImpl().ignoredOps.insert(nested);
2059  });
2060  }
2061 
2062  return success();
2063  }
2064 
2065  // Check to see if the operation is ignored and doesn't need to be converted.
2066  if (rewriter.getImpl().isOpIgnored(op)) {
2067  LLVM_DEBUG({
2068  logSuccess(logger, "operation marked 'ignored' during conversion");
2069  logger.startLine() << logLineComment;
2070  });
2071  return success();
2072  }
2073 
2074  // If the operation isn't legal, try to fold it in-place.
2075  // TODO: Should we always try to do this, even if the op is
2076  // already legal?
2077  if (succeeded(legalizeWithFold(op, rewriter))) {
2078  LLVM_DEBUG({
2079  logSuccess(logger, "operation was folded");
2080  logger.startLine() << logLineComment;
2081  });
2082  return success();
2083  }
2084 
2085  // Otherwise, we need to apply a legalization pattern to this operation.
2086  if (succeeded(legalizeWithPattern(op, rewriter))) {
2087  LLVM_DEBUG({
2088  logSuccess(logger, "");
2089  logger.startLine() << logLineComment;
2090  });
2091  return success();
2092  }
2093 
2094  LLVM_DEBUG({
2095  logFailure(logger, "no matched legalization pattern");
2096  logger.startLine() << logLineComment;
2097  });
2098  return failure();
2099 }
2100 
2101 LogicalResult
2102 OperationLegalizer::legalizeWithFold(Operation *op,
2103  ConversionPatternRewriter &rewriter) {
2104  auto &rewriterImpl = rewriter.getImpl();
2105  LLVM_DEBUG({
2106  rewriterImpl.logger.startLine() << "* Fold {\n";
2107  rewriterImpl.logger.indent();
2108  });
2109  (void)rewriterImpl;
2110 
2111  // Try to fold the operation.
2112  SmallVector<Value, 2> replacementValues;
2113  SmallVector<Operation *, 2> newOps;
2114  rewriter.setInsertionPoint(op);
2115  if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2116  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2117  return failure();
2118  }
2119 
2120  // An empty list of replacement values indicates that the fold was in-place.
2121  // As the operation changed, a new legalization needs to be attempted.
2122  if (replacementValues.empty())
2123  return legalize(op, rewriter);
2124 
2125  // Recursively legalize any new constant operations.
2126  for (Operation *newOp : newOps) {
2127  if (failed(legalize(newOp, rewriter))) {
2128  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2129  "failed to legalize generated constant '{0}'",
2130  newOp->getName()));
2131  // Legalization failed: erase all materialized constants.
2132  for (Operation *op : newOps)
2133  rewriter.eraseOp(op);
2134  return failure();
2135  }
2136  }
2137 
2138  // Insert a replacement for 'op' with the folded replacement values.
2139  rewriter.replaceOp(op, replacementValues);
2140 
2141  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2142  return success();
2143 }
2144 
2145 LogicalResult
2146 OperationLegalizer::legalizeWithPattern(Operation *op,
2147  ConversionPatternRewriter &rewriter) {
2148  auto &rewriterImpl = rewriter.getImpl();
2149 
2150  // Functor that returns if the given pattern may be applied.
2151  auto canApply = [&](const Pattern &pattern) {
2152  bool canApply = canApplyPattern(op, pattern, rewriter);
2153  if (canApply && config.listener)
2154  config.listener->notifyPatternBegin(pattern, op);
2155  return canApply;
2156  };
2157 
2158  // Functor that cleans up the rewriter state after a pattern failed to match.
2159  RewriterState curState = rewriterImpl.getCurrentState();
2160  auto onFailure = [&](const Pattern &pattern) {
2161  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2162  LLVM_DEBUG({
2163  logFailure(rewriterImpl.logger, "pattern failed to match");
2164  if (rewriterImpl.config.notifyCallback) {
2166  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2167  << "\" on op:\n"
2168  << *op;
2169  rewriterImpl.config.notifyCallback(diag);
2170  }
2171  });
2172  if (config.listener)
2173  config.listener->notifyPatternEnd(pattern, failure());
2174  rewriterImpl.resetState(curState, pattern.getDebugName());
2175  appliedPatterns.erase(&pattern);
2176  };
2177 
2178  // Functor that performs additional legalization when a pattern is
2179  // successfully applied.
2180  auto onSuccess = [&](const Pattern &pattern) {
2181  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2182  auto result = legalizePatternResult(op, pattern, rewriter, curState);
2183  appliedPatterns.erase(&pattern);
2184  if (failed(result)) {
2185  if (!rewriterImpl.config.allowPatternRollback)
2186  op->emitError("pattern '")
2187  << pattern.getDebugName()
2188  << "' produced IR that could not be legalized";
2189  rewriterImpl.resetState(curState, pattern.getDebugName());
2190  }
2191  if (config.listener)
2192  config.listener->notifyPatternEnd(pattern, result);
2193  return result;
2194  };
2195 
2196  // Try to match and rewrite a pattern on this operation.
2197  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2198  onSuccess);
2199 }
2200 
2201 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2202  ConversionPatternRewriter &rewriter) {
2203  LLVM_DEBUG({
2204  auto &os = rewriter.getImpl().logger;
2205  os.getOStream() << "\n";
2206  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2207  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2208  os.getOStream() << ")' {\n";
2209  os.indent();
2210  });
2211 
2212  // Ensure that we don't cycle by not allowing the same pattern to be
2213  // applied twice in the same recursion stack if it is not known to be safe.
2214  if (!pattern.hasBoundedRewriteRecursion() &&
2215  !appliedPatterns.insert(&pattern).second) {
2216  LLVM_DEBUG(
2217  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2218  return false;
2219  }
2220  return true;
2221 }
2222 
2223 LogicalResult
2224 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2225  ConversionPatternRewriter &rewriter,
2226  RewriterState &curState) {
2227  auto &impl = rewriter.getImpl();
2228  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2229 
2230 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2231  // Check that the root was either replaced or updated in place.
2232  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2233  auto replacedRoot = [&] {
2234  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2235  };
2236  auto updatedRootInPlace = [&] {
2237  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2238  };
2239  if (!replacedRoot() && !updatedRootInPlace())
2240  llvm::report_fatal_error("expected pattern to replace the root operation");
2241 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2242 
2243  // Legalize each of the actions registered during application.
2244  RewriterState newState = impl.getCurrentState();
2245  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2246  newState)) ||
2247  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2248  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2249  newState))) {
2250  return failure();
2251  }
2252 
2253  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2254  return success();
2255 }
2256 
2257 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2258  Operation *op, ConversionPatternRewriter &rewriter,
2259  ConversionPatternRewriterImpl &impl, RewriterState &state,
2260  RewriterState &newState) {
2261  SmallPtrSet<Operation *, 16> operationsToIgnore;
2262 
2263  // If the pattern moved or created any blocks, make sure the types of block
2264  // arguments get legalized.
2265  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2266  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2267  if (!rewrite)
2268  continue;
2269  Block *block = rewrite->getBlock();
2270  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2271  ReplaceBlockArgRewrite>(rewrite))
2272  continue;
2273  // Only check blocks outside of the current operation.
2274  Operation *parentOp = block->getParentOp();
2275  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2276  continue;
2277 
2278  // If the region of the block has a type converter, try to convert the block
2279  // directly.
2280  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2281  std::optional<TypeConverter::SignatureConversion> conversion =
2282  converter->convertBlockSignature(block);
2283  if (!conversion) {
2284  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2285  "block"));
2286  return failure();
2287  }
2288  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2289  continue;
2290  }
2291 
2292  // Otherwise, check that this operation isn't one generated by this pattern.
2293  // This is because we will attempt to legalize the parent operation, and
2294  // blocks in regions created by this pattern will already be legalized later
2295  // on. If we haven't built the set yet, build it now.
2296  if (operationsToIgnore.empty()) {
2297  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2298  ++i) {
2299  auto *createOp =
2300  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2301  if (!createOp)
2302  continue;
2303  operationsToIgnore.insert(createOp->getOperation());
2304  }
2305  }
2306 
2307  // If this operation should be considered for re-legalization, try it.
2308  if (operationsToIgnore.insert(parentOp).second &&
2309  failed(legalize(parentOp, rewriter))) {
2310  LLVM_DEBUG(logFailure(impl.logger,
2311  "operation '{0}'({1}) became illegal after rewrite",
2312  parentOp->getName(), parentOp));
2313  return failure();
2314  }
2315  }
2316  return success();
2317 }
2318 
2319 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2321  RewriterState &state, RewriterState &newState) {
2322  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2323  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2324  if (!createOp)
2325  continue;
2326  Operation *op = createOp->getOperation();
2327  if (failed(legalize(op, rewriter))) {
2328  LLVM_DEBUG(logFailure(impl.logger,
2329  "failed to legalize generated operation '{0}'({1})",
2330  op->getName(), op));
2331  return failure();
2332  }
2333  }
2334  return success();
2335 }
2336 
2337 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2339  RewriterState &state, RewriterState &newState) {
2340  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2341  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2342  if (!rewrite)
2343  continue;
2344  Operation *op = rewrite->getOperation();
2345  if (failed(legalize(op, rewriter))) {
2346  LLVM_DEBUG(logFailure(
2347  impl.logger, "failed to legalize operation updated in-place '{0}'",
2348  op->getName()));
2349  return failure();
2350  }
2351  }
2352  return success();
2353 }
2354 
2355 //===----------------------------------------------------------------------===//
2356 // Cost Model
2357 //===----------------------------------------------------------------------===//
2358 
2359 void OperationLegalizer::buildLegalizationGraph(
2360  LegalizationPatterns &anyOpLegalizerPatterns,
2361  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2362  // A mapping between an operation and a set of operations that can be used to
2363  // generate it.
2365  // A mapping between an operation and any currently invalid patterns it has.
2367  // A worklist of patterns to consider for legality.
2368  SetVector<const Pattern *> patternWorklist;
2369 
2370  // Build the mapping from operations to the parent ops that may generate them.
2371  applicator.walkAllPatterns([&](const Pattern &pattern) {
2372  std::optional<OperationName> root = pattern.getRootKind();
2373 
2374  // If the pattern has no specific root, we can't analyze the relationship
2375  // between the root op and generated operations. Given that, add all such
2376  // patterns to the legalization set.
2377  if (!root) {
2378  anyOpLegalizerPatterns.push_back(&pattern);
2379  return;
2380  }
2381 
2382  // Skip operations that are always known to be legal.
2383  if (target.getOpAction(*root) == LegalizationAction::Legal)
2384  return;
2385 
2386  // Add this pattern to the invalid set for the root op and record this root
2387  // as a parent for any generated operations.
2388  invalidPatterns[*root].insert(&pattern);
2389  for (auto op : pattern.getGeneratedOps())
2390  parentOps[op].insert(*root);
2391 
2392  // Add this pattern to the worklist.
2393  patternWorklist.insert(&pattern);
2394  });
2395 
2396  // If there are any patterns that don't have a specific root kind, we can't
2397  // make direct assumptions about what operations will never be legalized.
2398  // Note: Technically we could, but it would require an analysis that may
2399  // recurse into itself. It would be better to perform this kind of filtering
2400  // at a higher level than here anyways.
2401  if (!anyOpLegalizerPatterns.empty()) {
2402  for (const Pattern *pattern : patternWorklist)
2403  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2404  return;
2405  }
2406 
2407  while (!patternWorklist.empty()) {
2408  auto *pattern = patternWorklist.pop_back_val();
2409 
2410  // Check to see if any of the generated operations are invalid.
2411  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2412  std::optional<LegalizationAction> action = target.getOpAction(op);
2413  return !legalizerPatterns.count(op) &&
2414  (!action || action == LegalizationAction::Illegal);
2415  }))
2416  continue;
2417 
2418  // Otherwise, if all of the generated operation are valid, this op is now
2419  // legal so add all of the child patterns to the worklist.
2420  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2421  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2422 
2423  // Add any invalid patterns of the parent operations to see if they have now
2424  // become legal.
2425  for (auto op : parentOps[*pattern->getRootKind()])
2426  patternWorklist.set_union(invalidPatterns[op]);
2427  }
2428 }
2429 
2430 void OperationLegalizer::computeLegalizationGraphBenefit(
2431  LegalizationPatterns &anyOpLegalizerPatterns,
2432  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2433  // The smallest pattern depth, when legalizing an operation.
2434  DenseMap<OperationName, unsigned> minOpPatternDepth;
2435 
2436  // For each operation that is transitively legal, compute a cost for it.
2437  for (auto &opIt : legalizerPatterns)
2438  if (!minOpPatternDepth.count(opIt.first))
2439  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2440  legalizerPatterns);
2441 
2442  // Apply the cost model to the patterns that can match any operation. Those
2443  // with a specific operation type are already resolved when computing the op
2444  // legalization depth.
2445  if (!anyOpLegalizerPatterns.empty())
2446  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2447  legalizerPatterns);
2448 
2449  // Apply a cost model to the pattern applicator. We order patterns first by
2450  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2451  // decreasing benefit.
2452  applicator.applyCostModel([&](const Pattern &pattern) {
2453  ArrayRef<const Pattern *> orderedPatternList;
2454  if (std::optional<OperationName> rootName = pattern.getRootKind())
2455  orderedPatternList = legalizerPatterns[*rootName];
2456  else
2457  orderedPatternList = anyOpLegalizerPatterns;
2458 
2459  // If the pattern is not found, then it was removed and cannot be matched.
2460  auto *it = llvm::find(orderedPatternList, &pattern);
2461  if (it == orderedPatternList.end())
2463 
2464  // Patterns found earlier in the list have higher benefit.
2465  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2466  });
2467 }
2468 
2469 unsigned OperationLegalizer::computeOpLegalizationDepth(
2470  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2471  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2472  // Check for existing depth.
2473  auto depthIt = minOpPatternDepth.find(op);
2474  if (depthIt != minOpPatternDepth.end())
2475  return depthIt->second;
2476 
2477  // If a mapping for this operation does not exist, then this operation
2478  // is always legal. Return 0 as the depth for a directly legal operation.
2479  auto opPatternsIt = legalizerPatterns.find(op);
2480  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2481  return 0u;
2482 
2483  // Record this initial depth in case we encounter this op again when
2484  // recursively computing the depth.
2485  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2486 
2487  // Apply the cost model to the operation patterns, and update the minimum
2488  // depth.
2489  unsigned minDepth = applyCostModelToPatterns(
2490  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2491  minOpPatternDepth[op] = minDepth;
2492  return minDepth;
2493 }
2494 
2495 unsigned OperationLegalizer::applyCostModelToPatterns(
2496  LegalizationPatterns &patterns,
2497  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2498  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2499  unsigned minDepth = std::numeric_limits<unsigned>::max();
2500 
2501  // Compute the depth for each pattern within the set.
2502  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2503  patternsByDepth.reserve(patterns.size());
2504  for (const Pattern *pattern : patterns) {
2505  unsigned depth = 1;
2506  for (auto generatedOp : pattern->getGeneratedOps()) {
2507  unsigned generatedOpDepth = computeOpLegalizationDepth(
2508  generatedOp, minOpPatternDepth, legalizerPatterns);
2509  depth = std::max(depth, generatedOpDepth + 1);
2510  }
2511  patternsByDepth.emplace_back(pattern, depth);
2512 
2513  // Update the minimum depth of the pattern list.
2514  minDepth = std::min(minDepth, depth);
2515  }
2516 
2517  // If the operation only has one legalization pattern, there is no need to
2518  // sort them.
2519  if (patternsByDepth.size() == 1)
2520  return minDepth;
2521 
2522  // Sort the patterns by those likely to be the most beneficial.
2523  llvm::stable_sort(patternsByDepth,
2524  [](const std::pair<const Pattern *, unsigned> &lhs,
2525  const std::pair<const Pattern *, unsigned> &rhs) {
2526  // First sort by the smaller pattern legalization
2527  // depth.
2528  if (lhs.second != rhs.second)
2529  return lhs.second < rhs.second;
2530 
2531  // Then sort by the larger pattern benefit.
2532  auto lhsBenefit = lhs.first->getBenefit();
2533  auto rhsBenefit = rhs.first->getBenefit();
2534  return lhsBenefit > rhsBenefit;
2535  });
2536 
2537  // Update the legalization pattern to use the new sorted list.
2538  patterns.clear();
2539  for (auto &patternIt : patternsByDepth)
2540  patterns.push_back(patternIt.first);
2541  return minDepth;
2542 }
2543 
2544 //===----------------------------------------------------------------------===//
2545 // OperationConverter
2546 //===----------------------------------------------------------------------===//
namespace {
/// Selects how `OperationConverter` reacts to operations that fail to
/// legalize.
enum OpConversionMode {
  /// Ignore failed conversions so that illegal operations may co-exist in the
  /// IR (ops explicitly marked illegal still cause a failure).
  Partial = 0,
  /// Every operation must become legal for the target, or conversion fails.
  Full = 1,
  /// Analyze legality only; no actual rewrites are applied to the operations
  /// on success.
  Analysis = 2,
};
} // namespace
2562 
2563 namespace mlir {
2564 // This class converts operations to a given conversion target via a set of
2565 // rewrite patterns. The conversion behaves differently depending on the
2566 // conversion mode.
// NOTE(review): this copy is missing original line 2567 (presumably the
// `struct OperationConverter {` header; the `private:` below implies default
// public access) and line 2569 (presumably the `patterns` constructor
// parameter used in the initializer list, likely a
// `const FrozenRewritePatternSet &`). Restore them from upstream before
// building.
2568  explicit OperationConverter(const ConversionTarget &target,
2570  const ConversionConfig &config,
2571  OpConversionMode mode)
// `opLegalizer` is handed `this->config` (the member copy), not the
// constructor parameter, presumably so that anything the legalizer keeps
// refers to the config owned by this converter. This relies on `config` being
// declared before `opLegalizer` below (members initialize in declaration
// order).
2572  : config(config), opLegalizer(target, patterns, this->config),
2573  mode(mode) {}
2574 
2575  /// Converts the given operations to the conversion target.
2576  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2577 
2578 private:
2579  /// Converts an operation with the given rewriter.
2580  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2581 
2582  /// Dialect conversion configuration.
2583  ConversionConfig config;
2584 
2585  /// The legalizer to use when converting operations.
2586  OperationLegalizer opLegalizer;
2587 
2588  /// The conversion mode to use when legalizing operations.
2589  OpConversionMode mode;
2590 };
2591 } // namespace mlir
2592 
2593 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2594  Operation *op) {
2595  // Legalize the given operation.
2596  if (failed(opLegalizer.legalize(op, rewriter))) {
2597  // Handle the case of a failed conversion for each of the different modes.
2598  // Full conversions expect all operations to be converted.
2599  if (mode == OpConversionMode::Full)
2600  return op->emitError()
2601  << "failed to legalize operation '" << op->getName() << "'";
2602  // Partial conversions allow conversions to fail iff the operation was not
2603  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2604  // set, non-legalizable ops are added to that set.
2605  if (mode == OpConversionMode::Partial) {
2606  if (opLegalizer.isIllegal(op))
2607  return op->emitError()
2608  << "failed to legalize operation '" << op->getName()
2609  << "' that was explicitly marked illegal";
2610  if (config.unlegalizedOps)
2611  config.unlegalizedOps->insert(op);
2612  }
2613  } else if (mode == OpConversionMode::Analysis) {
2614  // Analysis conversions don't fail if any operations fail to legalize,
2615  // they are only interested in the operations that were successfully
2616  // legalized.
2617  if (config.legalizableOps)
2618  config.legalizableOps->insert(op);
2619  }
2620  return success();
2621 }
2622 
/// Tries to replace the still-used unrealized_conversion_cast behind `rewrite`
/// with a materialization built by the rewrite's type converter (target or
/// source kind). Emits an error with a note at the first live user, and fails,
/// when there is no converter or no materialization callback produced values.
/// NOTE(review): original line 2624 (the rest of the signature) is missing
/// from this copy; the declaration listed at the end of this page reads
/// `legalizeUnresolvedMaterialization(RewriterBase &rewriter,
/// UnresolvedMaterializationRewrite *rewrite)`.
2623 static LogicalResult
2625  UnresolvedMaterializationRewrite *rewrite) {
2626  UnrealizedConversionCastOp op = rewrite->getOperation();
2627  assert(!op.use_empty() &&
2628  "expected that dead materializations have already been DCE'd");
2629  Operation::operand_range inputOperands = op.getOperands();
2630 
2631  // Try to materialize the conversion.
2632  if (const TypeConverter *converter = rewrite->getConverter()) {
2633  rewriter.setInsertionPoint(op);
2634  SmallVector<Value> newMaterialization;
2635  switch (rewrite->getMaterializationKind()) {
2636  case MaterializationKind::Target:
2637  newMaterialization = converter->materializeTargetConversion(
2638  rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2639  rewrite->getOriginalType());
2640  break;
2641  case MaterializationKind::Source:
// Source materializations produce exactly one value.
2642  assert(op->getNumResults() == 1 && "expected single result");
2643  Value sourceMat = converter->materializeSourceConversion(
2644  rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2645  if (sourceMat)
2646  newMaterialization.push_back(sourceMat);
2647  break;
2648  }
// An empty result means no callback handled the conversion; fall through to
// the error below.
2649  if (!newMaterialization.empty()) {
2650 #ifndef NDEBUG
2651  ValueRange newMaterializationRange(newMaterialization);
2652  assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
2653  "materialization callback produced value of incorrect type");
2654 #endif // NDEBUG
2655  rewriter.replaceOp(op, newMaterialization);
2656  return success();
2657  }
2658  }
2659 
// The cast is known to have at least one user (see the assert at the top),
// so dereferencing the first user below is safe.
2660  InFlightDiagnostic diag = op->emitError()
2661  << "failed to legalize unresolved materialization "
2662  "from ("
2663  << inputOperands.getTypes() << ") to ("
2664  << op.getResultTypes()
2665  << ") that remained live after conversion";
2666  diag.attachNote(op->getUsers().begin()->getLoc())
2667  << "see existing live user here: " << *op->getUsers().begin();
2668  return failure();
2669 }
2670 
/// Converts `ops` and the ops nested in them, without descending into ops the
/// target reports as recursively legal. On the first failure from `convert`,
/// pending rewrites are undone (if `config.allowPatternRollback`) or applied
/// (otherwise), and failure is returned. On success, the rewrites are applied,
/// the framework-inserted unrealized casts are reconciled, and, if
/// `config.buildMaterializations` is set, the casts that remain are legalized
/// through their type converter.
/// NOTE(review): this copy is missing original line 2671 (presumably the
/// `LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {`
/// signature declared in the class), line 2679 (the walk call over `op` that
/// takes the lambda below; since the lambda returns WalkResult::skip(), it is
/// presumably a pre-order walk), and lines 2714-2715 (presumably the
/// declaration of `allCastOps` and the type of `materializations`).
2672  if (ops.empty())
2673  return success();
2674  const ConversionTarget &target = opLegalizer.getTarget();
2675 
2676  // Compute the set of operations and blocks to convert.
2677  SmallVector<Operation *> toConvert;
2678  for (auto *op : ops) {
2680  [&](Operation *op) {
2681  toConvert.push_back(op);
2682  // Don't check this operation's children for conversion if the
2683  // operation is recursively legal.
2684  auto legalityInfo = target.isLegal(op);
2685  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2686  return WalkResult::skip();
2687  return WalkResult::advance();
2688  });
2689  }
2690 
2691  // Convert each operation and discard rewrites on failure.
2692  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2693  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2694 
2695  for (auto *op : toConvert) {
2696  if (failed(convert(rewriter, op))) {
2697  // Dialect conversion failed.
2698  if (rewriterImpl.config.allowPatternRollback) {
2699  // Rollback is allowed: restore the original IR.
2700  rewriterImpl.undoRewrites();
2701  } else {
2702  // Rollback is not allowed: apply all modifications that have been
2703  // performed so far.
2704  rewriterImpl.applyRewrites();
2705  }
2706  return failure();
2707  }
2708  }
2709 
2710  // After a successful conversion, apply rewrites.
2711  rewriterImpl.applyRewrites();
2712 
2713  // Gather all unresolved materializations.
2716  &materializations = rewriterImpl.unresolvedMaterializations;
2717  for (auto it : materializations) {
// Skip casts that were already erased while the rewrites were applied.
2718  if (rewriterImpl.eraseRewriter.wasErased(it.first))
2719  continue;
2720  allCastOps.push_back(it.first);
2721  }
2722 
2723  // Reconcile all UnrealizedConversionCastOps that were inserted by the
2724  // dialect conversion frameworks. (Not the ones that were inserted by
2725  // patterns.)
2726  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2727  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2728 
2729  // Try to legalize all unresolved materializations.
2730  if (config.buildMaterializations) {
// The conversion rewrites were already applied above, so materializations
// are built with a plain IRRewriter (this shadows the conversion rewriter).
2731  IRRewriter rewriter(rewriterImpl.context, config.listener);
2732  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2733  auto it = materializations.find(castOp);
2734  assert(it != materializations.end() && "inconsistent state");
2735  if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2736  return failure();
2737  }
2738  }
2739 
2740  return success();
2741 }
2742 
2743 //===----------------------------------------------------------------------===//
2744 // Reconcile Unrealized Casts
2745 //===----------------------------------------------------------------------===//
2746 
/// Simplifies chains of unrealized_conversion_cast ops seeded from `castOps`.
/// Unused casts are erased, and any cast feeding them is revisited. A cast
/// whose result types equal the input types of some cast up its input chain is
/// replaced by that cast's inputs and erased. If `remainingCastOps` is
/// non-null, the ops of `castOps` that were not erased are appended to it.
/// NOTE(review): original lines 2747-2748 are missing from this copy
/// (presumably `void mlir::reconcileUnrealizedCasts(` and the `castOps`
/// parameter used below).
2749  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2750  SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
2751  // This set is maintained only if `remainingCastOps` is provided.
2752  DenseSet<Operation *> erasedOps;
2753 
2754  // Helper function that adds all operands to the worklist that are an
2755  // unrealized_conversion_cast op result.
2756  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2757  for (Value v : castOp.getInputs())
2758  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2759  worklist.insert(inputCastOp);
2760  };
2761 
2762  // Helper function that returns the unrealized_conversion_cast op that
2763  // defines all inputs of the given op (in the same order). Return "nullptr"
2764  // if there is no such op.
2765  auto getInputCast =
2766  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2767  if (castOp.getInputs().empty())
2768  return {};
2769  auto inputCastOp =
2770  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2771  if (!inputCastOp)
2772  return {};
2773  if (inputCastOp.getOutputs() != castOp.getInputs())
2774  return {};
2775  return inputCastOp;
2776  };
2777 
2778  // Process ops in the worklist bottom-to-top.
2779  while (!worklist.empty()) {
2780  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2781  if (castOp->use_empty()) {
2782  // DCE: If the op has no users, erase it. Add the operands to the
2783  // worklist to find additional DCE opportunities.
2784  enqueueOperands(castOp);
2785  if (remainingCastOps)
2786  erasedOps.insert(castOp.getOperation());
2787  castOp->erase();
2788  continue;
2789  }
2790 
2791  // Traverse the chain of input cast ops to see if an op with the same
2792  // input types can be found.
2793  UnrealizedConversionCastOp nextCast = castOp;
2794  while (nextCast) {
2795  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2796  // Found a cast where the input types match the output types of the
2797  // matched op. We can directly use those inputs and the matched op can
2798  // be removed.
2799  enqueueOperands(castOp);
2800  castOp.replaceAllUsesWith(nextCast.getInputs());
2801  if (remainingCastOps)
2802  erasedOps.insert(castOp.getOperation());
2803  castOp->erase();
2804  break;
2805  }
2806  nextCast = getInputCast(nextCast);
2807  }
2808  }
2809 
// Report the seed casts that survived, in their original order.
2810  if (remainingCastOps)
2811  for (UnrealizedConversionCastOp op : castOps)
2812  if (!erasedOps.contains(op.getOperation()))
2813  remainingCastOps->push_back(op);
2814 }
2815 
2816 //===----------------------------------------------------------------------===//
2817 // Type Conversion
2818 //===----------------------------------------------------------------------===//
2819 
/// Maps original input `origInputNo` to the new inputs `types`, which are
/// appended after the converted inputs recorded so far.
/// NOTE(review): original line 2820 (the start of the signature) is missing
/// from this copy.
2821  ArrayRef<Type> types) {
2822  assert(!types.empty() && "expected valid types");
2823  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2824  addInputs(types);
2825 }
2826 
/// Appends `types` to the converted input types without recording a
/// remapping for any original input.
/// NOTE(review): original line 2827 (the signature) is missing from this copy.
2828  assert(!types.empty() &&
2829  "1->0 type remappings don't need to be added explicitly");
2830  argTypes.append(types.begin(), types.end());
2831 }
2832 
2833 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2834  unsigned newInputNo,
2835  unsigned newInputCount) {
2836  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2837  assert(newInputCount != 0 && "expected valid input count");
2838  remappedInputs[origInputNo] =
2839  InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
2840 }
2841 
/// Records that original input `origInputNo` is replaced by the given values
/// instead of by new inputs (the mapping has size 0 and keeps `origInputNo`
/// as its input number).
/// NOTE(review): original line 2842 (the start of the signature) is missing
/// from this copy.
2843  unsigned origInputNo, ArrayRef<Value> replacements) {
2844  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2845  remappedInputs[origInputNo] = InputMapping{
2846  origInputNo, /*size=*/0,
2847  SmallVector<Value, 1>(replacements.begin(), replacements.end())};
2848 }
2849 
/// Converts `t` by appending its converted type(s) to `results`. Results are
/// cached: 1:1 conversions and failures go in `cachedDirectConversions` (a
/// null entry marks failure), other arities in `cachedMultiConversions`. The
/// cache is read under a shared lock and written under an exclusive lock.
/// When no registered callback handles `t`, failure is returned and nothing is
/// cached.
/// NOTE(review): this copy is missing original line 2850 (the start of the
/// signature) and lines 2857 and 2880, which presumably guard the two
/// `lock()` calls below (the locks are created with std::defer_lock, so they
/// are only taken conditionally), likely on whether multithreading is enabled.
2851  SmallVectorImpl<Type> &results) const {
2852  assert(t && "expected non-null type");
2853 
2854  {
2855  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2856  std::defer_lock);
2858  cacheReadLock.lock();
2859  auto existingIt = cachedDirectConversions.find(t);
2860  if (existingIt != cachedDirectConversions.end()) {
2861  if (existingIt->second)
2862  results.push_back(existingIt->second);
2863  return success(existingIt->second != nullptr);
2864  }
2865  auto multiIt = cachedMultiConversions.find(t);
2866  if (multiIt != cachedMultiConversions.end()) {
2867  results.append(multiIt->second.begin(), multiIt->second.end());
2868  return success();
2869  }
2870  }
2871  // Walk the added converters in reverse order to apply the most recently
2872  // registered first.
2873  size_t currentCount = results.size();
2874 
2875  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2876  std::defer_lock);
2877 
2878  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
// A callback returning std::nullopt does not handle `t`; try the next one.
2879  if (std::optional<LogicalResult> result = converter(t, results)) {
2881  cacheWriteLock.lock();
2882  if (!succeeded(*result)) {
2883  assert(results.size() == currentCount &&
2884  "failed type conversion should not change results");
2885  cachedDirectConversions.try_emplace(t, nullptr);
2886  return failure();
2887  }
// Only the types appended by this callback belong to `t`'s conversion.
2888  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2889  if (newTypes.size() == 1)
2890  cachedDirectConversions.try_emplace(t, newTypes.front());
2891  else
2892  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2893  return success();
2894  } else {
2895  assert(results.size() == currentCount &&
2896  "failed type conversion should not change results");
2897  }
2898  }
2899  return failure();
2900 }
2901 
/// Single-result convenience form: returns the converted type, or null when
/// conversion fails or does not produce exactly one type.
/// NOTE(review): original line 2902 (the signature) is missing from this copy.
2903  // Use the multi-type result version to convert the type.
2904  SmallVector<Type, 1> results;
2905  if (failed(convertType(t, results)))
2906  return nullptr;
2907 
2908  // Check to ensure that only one type was produced.
2909  return results.size() == 1 ? results.front() : nullptr;
2910 }
2911 
/// Converts each of `types` in order, appending the results to `results`.
/// Stops at the first failure. Conversions appended before that failure are
/// left in `results`.
/// NOTE(review): original line 2913 (the start of the signature) is missing
/// from this copy.
2912 LogicalResult
2914  SmallVectorImpl<Type> &results) const {
2915  for (Type type : types)
2916  if (failed(convertType(type, results)))
2917  return failure();
2918  return success();
2919 }
2920 
2921 bool TypeConverter::isLegal(Type type) const {
2922  return convertType(type) == type;
2923 }
/// An operation is legal when all of its operand and result types are legal.
/// NOTE(review): original line 2924 (the signature) is missing from this copy.
2925  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2926 }
2927 
2928 bool TypeConverter::isLegal(Region *region) const {
2929  return llvm::all_of(*region, [this](Block &block) {
2930  return isLegal(block.getArgumentTypes());
2931  });
2932 }
2933 
2934 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2935  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2936 }
2937 
/// Converts the type of input `inputNo` and records the result in `result`.
/// An input whose type converts to no types is dropped (no mapping is added).
/// NOTE(review): original line 2939 (the start of the signature) is missing
/// from this copy.
2938 LogicalResult
2940  SignatureConversion &result) const {
2941  // Try to convert the given input type.
2942  SmallVector<Type, 1> convertedTypes;
2943  if (failed(convertType(type, convertedTypes)))
2944  return failure();
2945 
2946  // If this argument is being dropped, there is nothing left to do.
2947  if (convertedTypes.empty())
2948  return success();
2949 
2950  // Otherwise, add the new inputs.
2951  result.addInputs(inputNo, convertedTypes);
2952  return success();
2953 }
/// Converts each of `types` as signature input `origInputOffset + i`,
/// stopping at the first failure.
/// NOTE(review): original line 2955 (the start of the signature) is missing
/// from this copy.
2954 LogicalResult
2956  SignatureConversion &result,
2957  unsigned origInputOffset) const {
2958  for (unsigned i = 0, e = types.size(); i != e; ++i)
2959  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2960  return failure();
2961  return success();
2962 }
2963 
/// Tries the registered source materializations, most recently added first,
/// and returns the first non-null value. Returns null if none succeeds.
/// NOTE(review): original line 2964 (the start of the signature) is missing
/// from this copy.
2965  Location loc, Type resultType,
2966  ValueRange inputs) const {
2967  for (const SourceMaterializationCallbackFn &fn :
2968  llvm::reverse(sourceMaterializations))
2969  if (Value result = fn(builder, resultType, inputs, loc))
2970  return result;
2971  return nullptr;
2972 }
2973 
/// Single-result convenience form of the multi-result target materialization.
/// Returns null if no callback produced a value.
/// NOTE(review): original lines 2974 (the start of the signature) and 2978
/// (presumably the declaration of `result` as the multi-result call) are
/// missing from this copy.
2975  Location loc, Type resultType,
2976  ValueRange inputs,
2977  Type originalType) const {
2979  builder, loc, TypeRange(resultType), inputs, originalType);
2980  if (result.empty())
2981  return nullptr;
2982  assert(result.size() == 1 && "expected single result");
2983  return result.front();
2984 }
2985 
/// Tries the registered target materializations, most recently added first,
/// and returns the values from the first callback that produces any. Returns
/// an empty vector if none does.
/// NOTE(review): original line 2986 (the start of the signature) is missing
/// from this copy.
2987  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
2988  Type originalType) const {
2989  for (const TargetMaterializationCallbackFn &fn :
2990  llvm::reverse(targetMaterializations)) {
2991  SmallVector<Value> result =
2992  fn(builder, resultTypes, inputs, loc, originalType);
// An empty result means this callback does not apply.
2993  if (result.empty())
2994  continue;
2995  assert(TypeRange(ValueRange(result)) == resultTypes &&
2996  "callback produced incorrect number of values or values with "
2997  "incorrect types");
2998  return result;
2999  }
3000  return {};
3001 }
3002 
/// Computes the signature conversion for the arguments of `block`, or
/// std::nullopt if any argument type fails to convert.
/// NOTE(review): original line 3004 (the signature) is missing from this copy.
3003 std::optional<TypeConverter::SignatureConversion>
3005  SignatureConversion conversion(block->getNumArguments());
3006  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3007  return std::nullopt;
3008  return conversion;
3009 }
3010 
3011 //===----------------------------------------------------------------------===//
3012 // Type attribute conversion
3013 //===----------------------------------------------------------------------===//
// NOTE(review): the signature lines of the AttributeConversionResult members
// below (original 3014-3015, 3019-3020, 3024-3025, 3029, 3033, 3037 and 3041)
// are missing from this copy. Each member stores or tests the tag kept in the
// integer bits of `impl`.
// Builds a result holding `attr`, tagged as a successful conversion.
3016  return AttributeConversionResult(attr, resultTag);
3017 }
3018 
// Builds an attribute-less result tagged "not applicable".
3021  return AttributeConversionResult(nullptr, naTag);
3022 }
3023 
// Builds an attribute-less result tagged "abort".
3026  return AttributeConversionResult(nullptr, abortTag);
3027 }
3028 
// True when this result carries a converted attribute.
3030  return impl.getInt() == resultTag;
3031 }
3032 
// True when the conversion was not applicable.
3034  return impl.getInt() == naTag;
3035 }
3036 
// True when the conversion asked to abort.
3038  return impl.getInt() == abortTag;
3039 }
3040 
// Returns the converted attribute; only valid for a result-tagged value.
3042  assert(hasResult() && "Cannot get result from N/A or abort");
3043  return impl.getPointer();
3044 }
3045 
/// Runs the type-attribute conversions, most recently added first. Returns the
/// first produced attribute. Returns std::nullopt if a callback aborts or if
/// none applies.
/// NOTE(review): original line 3047 (the signature) is missing from this copy.
3046 std::optional<Attribute>
3048  for (const TypeAttributeConversionCallbackFn &fn :
3049  llvm::reverse(typeAttributeConversions)) {
3050  AttributeConversionResult res = fn(type, attr);
3051  if (res.hasResult())
3052  return res.getResult();
3053  if (res.isAbort())
3054  return std::nullopt;
3055  }
3056  return std::nullopt;
3057 }
3058 
3059 //===----------------------------------------------------------------------===//
3060 // FunctionOpInterfaceSignatureConversion
3061 //===----------------------------------------------------------------------===//
3062 
3063 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3064  const TypeConverter &typeConverter,
3065  ConversionPatternRewriter &rewriter) {
3066  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3067  if (!type)
3068  return failure();
3069 
3070  // Convert the original function types.
3071  TypeConverter::SignatureConversion result(type.getNumInputs());
3072  SmallVector<Type, 1> newResults;
3073  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3074  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3075  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3076  typeConverter, &result)))
3077  return failure();
3078 
3079  // Update the function signature in-place.
3080  auto newType = FunctionType::get(rewriter.getContext(),
3081  result.getConvertedTypes(), newResults);
3082 
3083  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3084 
3085  return success();
3086 }
3087 
3088 /// Create a default conversion pattern that rewrites the type signature of a
3089 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3090 /// represent their type.
3091 namespace {
3092 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3093  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3094  MLIRContext *ctx,
3095  const TypeConverter &converter)
3096  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3097 
3098  LogicalResult
3099  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3100  ConversionPatternRewriter &rewriter) const override {
3101  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3102  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3103  }
3104 };
3105 
/// Conversion pattern that rewrites the type signature of any op implementing
/// FunctionOpInterface.
/// NOTE(review): original line 3108 is missing from this copy; presumably the
/// inherited-constructor declaration
/// (`using OpInterfaceConversionPattern::OpInterfaceConversionPattern;`), which
/// the `(converter, context)` construction in
/// populateAnyFunctionOpInterfaceTypeConversionPattern relies on.
3106 struct AnyFunctionOpInterfaceSignatureConversion
3107  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3109 
3110  LogicalResult
3111  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3112  ConversionPatternRewriter &rewriter) const override {
3113  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3114  }
3115 };
3116 } // namespace
3117 
/// Creates a copy of `op` (same name and attributes) that takes `operands` and
/// has `op`'s result types converted with `converter`. Fails as a match
/// failure if `op` is already legal or if a result type cannot be converted.
/// Regions and successors are not copied by the visible code.
/// NOTE(review): original line 3119 (the start of the signature, presumably
/// including `op` and `operands`) is missing from this copy.
3118 FailureOr<Operation *>
3120  const TypeConverter &converter,
3121  ConversionPatternRewriter &rewriter) {
3122  assert(op && "Invalid op");
3123  Location loc = op->getLoc();
3124  if (converter.isLegal(op))
3125  return rewriter.notifyMatchFailure(loc, "op already legal");
3126 
3127  OperationState newOp(loc, op->getName());
3128  newOp.addOperands(operands);
3129 
3130  SmallVector<Type> newResultTypes;
3131  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3132  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3133 
3134  newOp.addTypes(newResultTypes);
3135  newOp.addAttributes(op->getAttrs());
3136  return rewriter.create(newOp);
3137 }
3138 
/// Adds a signature-conversion pattern for ops named `functionLikeOpName`.
/// NOTE(review): original line 3139 (the start of the signature) is missing
/// from this copy.
3140  StringRef functionLikeOpName, RewritePatternSet &patterns,
3141  const TypeConverter &converter) {
3142  patterns.add<FunctionOpInterfaceSignatureConversion>(
3143  functionLikeOpName, patterns.getContext(), converter);
3144 }
3145 
/// Adds a signature-conversion pattern for any FunctionOpInterface op.
/// NOTE(review): original line 3146 (the start of the signature) is missing
/// from this copy.
3147  RewritePatternSet &patterns, const TypeConverter &converter) {
3148  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3149  converter, patterns.getContext());
3150 }
3151 
3152 //===----------------------------------------------------------------------===//
3153 // ConversionTarget
3154 //===----------------------------------------------------------------------===//
3155 
/// Sets the legalization action for operation `op` (creating its entry if
/// needed).
/// NOTE(review): the signature lines of this function and of the two below
/// (original 3156, 3161, 3167) are missing from this copy.
3157  LegalizationAction action) {
3158  legalOperations[op].action = action;
3159 }
3160 
/// Sets the same legalization action for every dialect in `dialectNames`.
3162  LegalizationAction action) {
3163  for (StringRef dialect : dialectNames)
3164  legalDialects[dialect] = action;
3165 }
3166 
/// Returns the legalization action that applies to `op` (via getOpInfo), or
/// std::nullopt if none is known.
3168  -> std::optional<LegalizationAction> {
3169  std::optional<LegalizationInfo> info = getOpInfo(op);
3170  return info ? info->action : std::optional<LegalizationAction>();
3171 }
3172 
/// Returns legality details if `op` is legal for this target, else
/// std::nullopt. An op is legal if its action is Legal, or if it is Dynamic
/// and the legality callback returns true. The details also report whether
/// the op is recursively legal (optionally refined by a per-op recursive
/// legality callback, whose std::nullopt counts as true).
/// NOTE(review): original line 3173 (the start of the signature) is missing
/// from this copy.
3174  -> std::optional<LegalOpDetails> {
3175  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3176  if (!info)
3177  return std::nullopt;
3178 
3179  // Returns true if this operation instance is known to be legal.
3180  auto isOpLegal = [&] {
3181  // Handle dynamic legality with the provided legality function, if it gives an answer.
// NOTE(review): `legalityFn` is called without a null check. getOpInfo can
// return a Dynamic dialect entry with an empty callback (no
// dialectLegalityFns entry). Calling an empty std::function throws
// std::bad_function_call, so consider guarding this call.
3182  if (info->action == LegalizationAction::Dynamic) {
3183  std::optional<bool> result = info->legalityFn(op);
3184  if (result)
3185  return *result;
3186  }
3187 
3188  // Otherwise, the operation is only legal if it was marked 'Legal'.
3189  return info->action == LegalizationAction::Legal;
3190  };
3191  if (!isOpLegal())
3192  return std::nullopt;
3193 
3194  // This operation is legal, compute any additional legality information.
3195  LegalOpDetails legalityDetails;
3196  if (info->isRecursivelyLegal) {
3197  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3198  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3199  legalityDetails.isRecursivelyLegal =
3200  legalityFnIt->second(op).value_or(true);
3201  } else {
3202  legalityDetails.isRecursivelyLegal = true;
3203  }
3204  }
3205  return legalityDetails;
3206 }
3207 
/// Returns true only if `op` is explicitly illegal: its action is Illegal, or
/// it is Dynamic and the legality callback returns false. Unknown ops and
/// callbacks returning std::nullopt are not considered illegal.
/// NOTE(review): original line 3208 (the signature) is missing from this copy.
/// As in isLegal, `legalityFn` is invoked without checking that it is
/// non-empty.
3209  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3210  if (!info)
3211  return false;
3212 
3213  if (info->action == LegalizationAction::Dynamic) {
3214  std::optional<bool> result = info->legalityFn(op);
3215  if (!result)
3216  return false;
3217 
3218  return !(*result);
3219  }
3220 
3221  return info->action == LegalizationAction::Illegal;
3222 }
3223 
/// Chains two legality callbacks. The new callback is consulted first, and the
/// old one only when the new one returns std::nullopt. If there is no old
/// callback, the new one is returned unchanged.
/// NOTE(review): original lines 3224-3226 (the signature) are missing from
/// this copy; the declaration listed at the end of this page reads
/// `static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
/// ConversionTarget::DynamicLegalityCallbackFn oldCallback,
/// ConversionTarget::DynamicLegalityCallbackFn newCallback)`.
3227  if (!oldCallback)
3228  return newCallback;
3229 
3230  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3231  Operation *op) -> std::optional<bool> {
3232  if (std::optional<bool> result = newCl(op))
3233  return *result;
3234 
3235  return oldCl(op);
3236  };
3237  return chain;
3238 }
3239 
3240 void ConversionTarget::setLegalityCallback(
3241  OperationName name, const DynamicLegalityCallbackFn &callback) {
3242  assert(callback && "expected valid legality callback");
3243  auto *infoIt = legalOperations.find(name);
3244  assert(infoIt != legalOperations.end() &&
3245  infoIt->second.action == LegalizationAction::Dynamic &&
3246  "expected operation to already be marked as dynamically legal");
3247  infoIt->second.legalityFn =
3248  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3249 }
3250 
/// Marks the (already legal or dynamic) op `name` as recursively legal. A
/// non-null `callback` is chained with any existing recursive-legality
/// callback. A null one removes it, so the op is unconditionally recursively
/// legal.
/// NOTE(review): original line 3251 (the start of the signature) is missing
/// from this copy.
3252  OperationName name, const DynamicLegalityCallbackFn &callback) {
3253  auto *infoIt = legalOperations.find(name);
3254  assert(infoIt != legalOperations.end() &&
3255  infoIt->second.action != LegalizationAction::Illegal &&
3256  "expected operation to already be marked as legal");
3257  infoIt->second.isRecursivelyLegal = true;
3258  if (callback)
3259  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3260  std::move(opRecursiveLegalityFns[name]), callback);
3261  else
3262  opRecursiveLegalityFns.erase(name);
3263 }
3264 
3265 void ConversionTarget::setLegalityCallback(
3266  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3267  assert(callback && "expected valid legality callback");
3268  for (StringRef dialect : dialects)
3269  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3270  std::move(dialectLegalityFns[dialect]), callback);
3271 }
3272 
3273 void ConversionTarget::setLegalityCallback(
3274  const DynamicLegalityCallbackFn &callback) {
3275  assert(callback && "expected valid legality callback");
3276  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3277 }
3278 
3279 auto ConversionTarget::getOpInfo(OperationName op) const
3280  -> std::optional<LegalizationInfo> {
3281  // Check for info for this specific operation.
3282  const auto *it = legalOperations.find(op);
3283  if (it != legalOperations.end())
3284  return it->second;
3285  // Check for info for the parent dialect.
3286  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3287  if (dialectIt != legalDialects.end()) {
3288  DynamicLegalityCallbackFn callback;
3289  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3290  if (dialectFn != dialectLegalityFns.end())
3291  callback = dialectFn->second;
3292  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3293  callback};
3294  }
3295  // Otherwise, check if we mark unknown operations as dynamic.
3296  if (unknownLegalityFn)
3297  return LegalizationInfo{LegalizationAction::Dynamic,
3298  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3299  return std::nullopt;
3300 }
3301 
3302 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3303 //===----------------------------------------------------------------------===//
3304 // PDL Configuration
3305 //===----------------------------------------------------------------------===//
3306 
/// At the start of a PDL rewrite, exposes this config's type converter to the
/// conversion rewriter (read by the PDL "convert*" functions).
/// NOTE(review): the signature lines of this function and the next (original
/// 3307 and 3313) are missing from this copy.
3308  auto &rewriterImpl =
3309  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3310  rewriterImpl.currentTypeConverter = getTypeConverter();
3311 }
3312 
/// At the end of a PDL rewrite, clears the current type converter again.
3314  auto &rewriterImpl =
3315  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3316  rewriterImpl.currentTypeConverter = nullptr;
3317 }
3318 
3319 /// Remap the given values using the rewriter and the type converter in the
3320 /// provided config.
/// NOTE(review): original line 3322 (the rest of the signature) is missing
/// from this copy; the declaration listed at the end of this page reads
/// `pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)`.
3321 static FailureOr<SmallVector<Value>>
3323  SmallVector<Value> mappedValues;
3324  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3325  return failure();
3326  return std::move(mappedValues);
3327 }
3328 
/// Registers the PDL rewrite functions "convertValue", "convertValues",
/// "convertType" and "convertTypes". The type functions use the rewriter's
/// current type converter; without one, they return their input unchanged.
/// NOTE(review): original line 3329 (the signature) is missing from this copy.
3330  patterns.getPDLPatterns().registerRewriteFunction(
3331  "convertValue",
3332  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3333  auto results = pdllConvertValues(
3334  static_cast<ConversionPatternRewriter &>(rewriter), value);
3335  if (failed(results))
3336  return failure();
3337  return results->front();
3338  });
3339  patterns.getPDLPatterns().registerRewriteFunction(
3340  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3341  return pdllConvertValues(
3342  static_cast<ConversionPatternRewriter &>(rewriter), values);
3343  });
3344  patterns.getPDLPatterns().registerRewriteFunction(
3345  "convertType",
3346  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3347  auto &rewriterImpl =
3348  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3349  if (const TypeConverter *converter =
3350  rewriterImpl.currentTypeConverter) {
// Uses the single-result convertType, so 1:N conversions count as failure.
3351  if (Type newType = converter->convertType(type))
3352  return newType;
3353  return failure();
3354  }
3355  return type;
3356  });
3357  patterns.getPDLPatterns().registerRewriteFunction(
3358  "convertTypes",
3359  [](PatternRewriter &rewriter,
3360  TypeRange types) -> FailureOr<SmallVector<Type>> {
3361  auto &rewriterImpl =
3362  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3363  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3364  if (!converter)
3365  return SmallVector<Type>(types);
3366 
3367  SmallVector<Type> remappedTypes;
3368  if (failed(converter->convertTypes(types, remappedTypes)))
3369  return failure();
3370  return std::move(remappedTypes);
3371  });
3372 }
3373 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3374 
3375 //===----------------------------------------------------------------------===//
3376 // Op Conversion Entry Points
3377 //===----------------------------------------------------------------------===//
3378 
3379 //===----------------------------------------------------------------------===//
3380 // Partial Conversion
3381 //===----------------------------------------------------------------------===//
3382 
/// Applies a partial conversion to `ops`: ops that fail to legalize may remain
/// unless they are explicitly illegal.
/// NOTE(review): the signature lines of both overloads (original 3383, 3385,
/// 3391-3393) are missing from this copy.
3384  ArrayRef<Operation *> ops, const ConversionTarget &target,
3386  OperationConverter opConverter(target, patterns, config,
3387  OpConversionMode::Partial);
3388  return opConverter.convertOperations(ops);
3389 }
/// Single-op overload; forwards to the ArrayRef version.
3390 LogicalResult
3394  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3395 }
3396 
3397 //===----------------------------------------------------------------------===//
3398 // Full Conversion
3399 //===----------------------------------------------------------------------===//
3400 
/// Applies a full conversion to `ops`: every op must become legal.
/// NOTE(review): the signature lines of both overloads (original 3401,
/// 3403-3404, 3409, 3411-3412) are missing from this copy.
3402  const ConversionTarget &target,
3405  OperationConverter opConverter(target, patterns, config,
3406  OpConversionMode::Full);
3407  return opConverter.convertOperations(ops);
3408 }
/// Single-op overload; forwards to the ArrayRef version.
3410  const ConversionTarget &target,
3413  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3414 }
3415 
3416 //===----------------------------------------------------------------------===//
3417 // Analysis Conversion
3418 //===----------------------------------------------------------------------===//
3419 
3420 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3421 /// op is a top-level module op (which is expected to be isolated from above),
3422 /// return that op.
/// Assumes `ops` is non-empty (`ops.front()` is accessed unconditionally).
/// NOTE(review): this copy is missing original line 3423 (the signature;
/// the declaration listed at the end of this page reads
/// `static Operation *findCommonAncestor(ArrayRef<Operation *> ops)`) and
/// line 3445 (presumably the right-hand side of the assignment below, moving
/// `commonAncestor` to its next IsolatedFromAbove parent).
3424  // Check if there is a top-level operation within `ops`. If so, return that
3425  // op.
3426  for (Operation *op : ops) {
3427  if (!op->getParentOp()) {
3428 #ifndef NDEBUG
3429  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3430  "expected top-level op to be isolated from above");
3431  for (Operation *other : ops)
3432  assert(op->isAncestor(other) &&
3433  "expected ops to have a common ancestor");
3434 #endif // NDEBUG
3435  return op;
3436  }
3437  }
3438 
3439  // No top-level op. Find a common ancestor.
3440  Operation *commonAncestor =
3441  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
// Widen the candidate until it properly contains every other op.
3442  for (Operation *op : ops.drop_front()) {
3443  while (!commonAncestor->isProperAncestor(op)) {
3444  commonAncestor =
3446  assert(commonAncestor &&
3447  "expected to find a common isolated from above ancestor");
3448  }
3449  }
3450 
3451  return commonAncestor;
3452 }
3453 
/// Analyzes which of `ops` could be legalized, without modifying them. The
/// closest common IsolatedFromAbove ancestor is cloned, the clones of `ops`
/// are converted in Analysis mode, and `config.legalizableOps` (expected empty
/// on entry) is remapped to the original ops. The clone is then erased.
/// NOTE(review): original lines 3454-3456 (the signature) are missing from
/// this copy.
3457 #ifndef NDEBUG
3458  if (config.legalizableOps)
3459  assert(config.legalizableOps->empty() && "expected empty set");
3460 #endif // NDEBUG
3461 
3462  // Clone the closest common ancestor that is isolated from above.
3463  Operation *commonAncestor = findCommonAncestor(ops);
3464  IRMapping mapping;
3465  Operation *clonedAncestor = commonAncestor->clone(mapping);
3466  // Compute inverse IR mapping.
3467  DenseMap<Operation *, Operation *> inverseOperationMap;
3468  for (auto &it : mapping.getOperationMap())
3469  inverseOperationMap[it.second] = it.first;
3470 
3471  // Convert the cloned operations. The original IR will remain unchanged.
3472  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3473  ops, [&](Operation *op) { return mapping.lookup(op); });
3474  OperationConverter opConverter(target, patterns, config,
3475  OpConversionMode::Analysis);
3476  LogicalResult status = opConverter.convertOperations(opsToConvert);
3477 
3478  // Remap `legalizableOps`, so that they point to the original ops and not the
3479  // cloned ops.
// NOTE(review): operator[] would insert a null entry for an op that has no
// original counterpart; this presumes every recorded op is a clone.
3480  if (config.legalizableOps) {
3481  DenseSet<Operation *> originalLegalizableOps;
3482  for (Operation *op : *config.legalizableOps)
3483  originalLegalizableOps.insert(inverseOperationMap[op]);
3484  *config.legalizableOps = std::move(originalLegalizableOps);
3485  }
3486 
3487  // Erase the cloned IR.
3488  clonedAncestor->erase();
3489  return status;
3490 }
3490 }
3491 
/// Single-op overload; forwards to the ArrayRef version.
/// NOTE(review): original lines 3493-3495 (the signature) are missing from
/// this copy.
3492 LogicalResult
3496  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3497 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
union mlir::linalg::@1194::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:295
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:304
Location getLoc() const
Return the location for this argument.
Definition: Value.h:310
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:68
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:96
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void dropAllUses()
Drop all uses of this object from their respective owners.
Definition: UseDefLists.h:202
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:730
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:324
Block::iterator getPoint() const
Definition: Builders.h:337
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:334
Block * getBlock() const
Definition: Builders.h:336
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:395
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:317
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:469
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:243
This is a value defined by a result of an operation.
Definition: Value.h:433
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getTypes() const
Definition: ValueRange.cpp:28
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:296
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:719
operand_iterator operand_begin()
Definition: Operation.h:374
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_iterator operand_end()
Definition: Operation.h:375
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:607
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
result_range getResults()
Definition: Operation.h:415
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:896
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
succ_iterator successor_end()
Definition: Operation.h:702
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
succ_iterator successor_begin()
Definition: Operation.h:701
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:602
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacements values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:204
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a range of values that remaps an existing signature in...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/block that were already erased and skips duplicate op/block erasur...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp=nullptr)
Build an unresolved materialization operation given a range of output types and a list of input opera...
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
void notifyOpReplaced(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Notifies that an op is about to be replaced with the given values.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.