MLIR  20.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Iterators.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::detail;
29 
30 #define DEBUG_TYPE "dialect-conversion"
31 
32 /// A utility function to log a successful result for the given reason.
33 template <typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
35  LLVM_DEBUG({
36  os.unindent();
37  os.startLine() << "} -> SUCCESS";
38  if (!fmt.empty())
39  os.getOStream() << " : "
40  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41  os.getOStream() << "\n";
42  });
43 }
44 
45 /// A utility function to log a failure result for the given reason.
46 template <typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
48  LLVM_DEBUG({
49  os.unindent();
50  os.startLine() << "} -> FAILURE : "
51  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
52  << "\n";
53  });
54 }
55 
56 /// Helper function that computes an insertion point where the given value is
57 /// defined and can be used without a dominance violation.
59  Block *insertBlock = value.getParentBlock();
60  Block::iterator insertPt = insertBlock->begin();
61  if (OpResult inputRes = dyn_cast<OpResult>(value))
62  insertPt = ++inputRes.getOwner()->getIterator();
63  return OpBuilder::InsertPoint(insertBlock, insertPt);
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // ConversionValueMapping
68 //===----------------------------------------------------------------------===//
69 
70 namespace {
71 /// This class wraps a IRMapping to provide recursive lookup
72 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
73 struct ConversionValueMapping {
74  /// Lookup the most recently mapped value with the desired type in the
75  /// mapping.
76  ///
77  /// Special cases:
78  /// - If the desired type is "null", simply return the most recently mapped
79  /// value.
80  /// - If there is no mapping to the desired type, also return the most
81  /// recently mapped value.
82  /// - If there is no mapping for the given value at all, return the given
83  /// value.
84  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
85 
86  /// Lookup a mapped value within the map, or return null if a mapping does not
87  /// exist. If a mapping exists, this follows the same behavior of
88  /// `lookupOrDefault`.
89  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
90 
91  /// Map a value to the one provided.
92  void map(Value oldVal, Value newVal) {
93  LLVM_DEBUG({
94  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
95  assert(it != oldVal && "inserting cyclic mapping");
96  });
97  mapping.map(oldVal, newVal);
98  }
99 
100  /// Try to map a value to the one provided. Returns false if a transitive
101  /// mapping from the new value to the old value already exists, true if the
102  /// map was updated.
103  bool tryMap(Value oldVal, Value newVal);
104 
105  /// Drop the last mapping for the given value.
106  void erase(Value value) { mapping.erase(value); }
107 
108  /// Returns the inverse raw value mapping (without recursive query support).
109  DenseMap<Value, SmallVector<Value>> getInverse() const {
111  for (auto &it : mapping.getValueMap())
112  inverse[it.second].push_back(it.first);
113  return inverse;
114  }
115 
116 private:
117  /// Current value mappings.
118  IRMapping mapping;
119 };
120 } // namespace
121 
122 Value ConversionValueMapping::lookupOrDefault(Value from,
123  Type desiredType) const {
124  // Try to find the deepest value that has the desired type. If there is no
125  // such value, simply return the deepest value.
126  Value desiredValue;
127  do {
128  if (!desiredType || from.getType() == desiredType)
129  desiredValue = from;
130 
131  Value mappedValue = mapping.lookupOrNull(from);
132  if (!mappedValue)
133  break;
134  from = mappedValue;
135  } while (true);
136 
137  // If the desired value was found use it, otherwise default to the leaf value.
138  return desiredValue ? desiredValue : from;
139 }
140 
141 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
142  Value result = lookupOrDefault(from, desiredType);
143  if (result == from || (desiredType && result.getType() != desiredType))
144  return nullptr;
145  return result;
146 }
147 
148 bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
149  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
150  if (it == oldVal)
151  return false;
152  map(oldVal, newVal);
153  return true;
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // Rewriter and Translation State
158 //===----------------------------------------------------------------------===//
159 namespace {
/// A snapshot of the conversion rewriter's bookkeeping counters. Saving one
/// and later resetting to it undoes everything recorded after the snapshot
/// was taken.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
        numReplacedOps(numReplacedOps) {}

  /// Number of rewrites recorded at snapshot time.
  unsigned numRewrites;
  /// Number of operations marked as ignored at snapshot time.
  unsigned numIgnoredOperations;
  /// Number of replaced ops scheduled for erasure at snapshot time.
  unsigned numReplacedOps;
};
177 
178 //===----------------------------------------------------------------------===//
179 // IR rewrites
180 //===----------------------------------------------------------------------===//
181 
182 /// An IR rewrite that can be committed (upon success) or rolled back (upon
183 /// failure).
184 ///
185 /// The dialect conversion keeps track of IR modifications (requested by the
186 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
187 /// are directly applied to the IR as the rewriter API is used, some are applied
188 /// partially, and some are delayed until the `IRRewrite` objects are committed.
189 class IRRewrite {
190 public:
191  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
192  /// Enum values are ordered, so that they can be used in `classof`: first all
193  /// block rewrites, then all operation rewrites.
194  enum class Kind {
195  // Block rewrites
196  CreateBlock,
197  EraseBlock,
198  InlineBlock,
199  MoveBlock,
200  BlockTypeConversion,
201  ReplaceBlockArg,
202  // Operation rewrites
203  MoveOperation,
204  ModifyOperation,
205  ReplaceOperation,
206  CreateOperation,
207  UnresolvedMaterialization
208  };
209 
210  virtual ~IRRewrite() = default;
211 
212  /// Roll back the rewrite. Operations may be erased during rollback.
213  virtual void rollback() = 0;
214 
215  /// Commit the rewrite. At this point, it is certain that the dialect
216  /// conversion will succeed. All IR modifications, except for operation/block
217  /// erasure, must be performed through the given rewriter.
218  ///
219  /// Instead of erasing operations/blocks, they should merely be unlinked
220  /// commit phase and finally be erased during the cleanup phase. This is
221  /// because internal dialect conversion state (such as `mapping`) may still
222  /// be using them.
223  ///
224  /// Any IR modification that was already performed before the commit phase
225  /// (e.g., insertion of an op) must be communicated to the listener that may
226  /// be attached to the given rewriter.
227  virtual void commit(RewriterBase &rewriter) {}
228 
229  /// Cleanup operations/blocks. Cleanup is called after commit.
230  virtual void cleanup(RewriterBase &rewriter) {}
231 
232  Kind getKind() const { return kind; }
233 
234  static bool classof(const IRRewrite *rewrite) { return true; }
235 
236 protected:
237  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
238  : kind(kind), rewriterImpl(rewriterImpl) {}
239 
240  const ConversionConfig &getConfig() const;
241 
242  const Kind kind;
243  ConversionPatternRewriterImpl &rewriterImpl;
244 };
245 
246 /// A block rewrite.
247 class BlockRewrite : public IRRewrite {
248 public:
249  /// Return the block that this rewrite operates on.
250  Block *getBlock() const { return block; }
251 
252  static bool classof(const IRRewrite *rewrite) {
253  return rewrite->getKind() >= Kind::CreateBlock &&
254  rewrite->getKind() <= Kind::ReplaceBlockArg;
255  }
256 
257 protected:
258  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
259  Block *block)
260  : IRRewrite(kind, rewriterImpl), block(block) {}
261 
262  // The block that this rewrite operates on.
263  Block *block;
264 };
265 
266 /// Creation of a block. Block creations are immediately reflected in the IR.
267 /// There is no extra work to commit the rewrite. During rollback, the newly
268 /// created block is erased.
269 class CreateBlockRewrite : public BlockRewrite {
270 public:
271  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
272  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
273 
274  static bool classof(const IRRewrite *rewrite) {
275  return rewrite->getKind() == Kind::CreateBlock;
276  }
277 
278  void commit(RewriterBase &rewriter) override {
279  // The block was already created and inserted. Just inform the listener.
280  if (auto *listener = rewriter.getListener())
281  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
282  }
283 
284  void rollback() override {
285  // Unlink all of the operations within this block, they will be deleted
286  // separately.
287  auto &blockOps = block->getOperations();
288  while (!blockOps.empty())
289  blockOps.remove(blockOps.begin());
290  block->dropAllUses();
291  if (block->getParent())
292  block->erase();
293  else
294  delete block;
295  }
296 };
297 
298 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
299 /// blocks are immediately unlinked, but only erased during cleanup. This makes
300 /// it easier to rollback a block erasure: the block is simply inserted into its
301 /// original location.
302 class EraseBlockRewrite : public BlockRewrite {
303 public:
304  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
305  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
306  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
307 
308  static bool classof(const IRRewrite *rewrite) {
309  return rewrite->getKind() == Kind::EraseBlock;
310  }
311 
312  ~EraseBlockRewrite() override {
313  assert(!block &&
314  "rewrite was neither rolled back nor committed/cleaned up");
315  }
316 
317  void rollback() override {
318  // The block (owned by this rewrite) was not actually erased yet. It was
319  // just unlinked. Put it back into its original position.
320  assert(block && "expected block");
321  auto &blockList = region->getBlocks();
322  Region::iterator before = insertBeforeBlock
323  ? Region::iterator(insertBeforeBlock)
324  : blockList.end();
325  blockList.insert(before, block);
326  block = nullptr;
327  }
328 
329  void commit(RewriterBase &rewriter) override {
330  // Erase the block.
331  assert(block && "expected block");
332  assert(block->empty() && "expected empty block");
333 
334  // Notify the listener that the block is about to be erased.
335  if (auto *listener =
336  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
337  listener->notifyBlockErased(block);
338  }
339 
340  void cleanup(RewriterBase &rewriter) override {
341  // Erase the block.
342  block->dropAllDefinedValueUses();
343  delete block;
344  block = nullptr;
345  }
346 
347 private:
348  // The region in which this block was previously contained.
349  Region *region;
350 
351  // The original successor of this block before it was unlinked. "nullptr" if
352  // this block was the only block in the region.
353  Block *insertBeforeBlock;
354 };
355 
356 /// Inlining of a block. This rewrite is immediately reflected in the IR.
357 /// Note: This rewrite represents only the inlining of the operations. The
358 /// erasure of the inlined block is a separate rewrite.
359 class InlineBlockRewrite : public BlockRewrite {
360 public:
361  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
362  Block *sourceBlock, Block::iterator before)
363  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
364  sourceBlock(sourceBlock),
365  firstInlinedInst(sourceBlock->empty() ? nullptr
366  : &sourceBlock->front()),
367  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
368  // If a listener is attached to the dialect conversion, ops must be moved
369  // one-by-one. When they are moved in bulk, notifications cannot be sent
370  // because the ops that used to be in the source block at the time of the
371  // inlining (before the "commit" phase) are unknown at the time when
372  // notifications are sent (which is during the "commit" phase).
373  assert(!getConfig().listener &&
374  "InlineBlockRewrite not supported if listener is attached");
375  }
376 
377  static bool classof(const IRRewrite *rewrite) {
378  return rewrite->getKind() == Kind::InlineBlock;
379  }
380 
381  void rollback() override {
382  // Put the operations from the destination block (owned by the rewrite)
383  // back into the source block.
384  if (firstInlinedInst) {
385  assert(lastInlinedInst && "expected operation");
386  sourceBlock->getOperations().splice(sourceBlock->begin(),
387  block->getOperations(),
388  Block::iterator(firstInlinedInst),
389  ++Block::iterator(lastInlinedInst));
390  }
391  }
392 
393 private:
394  // The block that originally contained the operations.
395  Block *sourceBlock;
396 
397  // The first inlined operation.
398  Operation *firstInlinedInst;
399 
400  // The last inlined operation.
401  Operation *lastInlinedInst;
402 };
403 
404 /// Moving of a block. This rewrite is immediately reflected in the IR.
405 class MoveBlockRewrite : public BlockRewrite {
406 public:
407  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
408  Region *region, Block *insertBeforeBlock)
409  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
410  insertBeforeBlock(insertBeforeBlock) {}
411 
412  static bool classof(const IRRewrite *rewrite) {
413  return rewrite->getKind() == Kind::MoveBlock;
414  }
415 
416  void commit(RewriterBase &rewriter) override {
417  // The block was already moved. Just inform the listener.
418  if (auto *listener = rewriter.getListener()) {
419  // Note: `previousIt` cannot be passed because this is a delayed
420  // notification and iterators into past IR state cannot be represented.
421  listener->notifyBlockInserted(block, /*previous=*/region,
422  /*previousIt=*/{});
423  }
424  }
425 
426  void rollback() override {
427  // Move the block back to its original position.
428  Region::iterator before =
429  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
430  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
431  }
432 
433 private:
434  // The region in which this block was previously contained.
435  Region *region;
436 
437  // The original successor of this block before it was moved. "nullptr" if
438  // this block was the only block in the region.
439  Block *insertBeforeBlock;
440 };
441 
442 /// Block type conversion. This rewrite is partially reflected in the IR.
443 class BlockTypeConversionRewrite : public BlockRewrite {
444 public:
445  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
446  Block *block, Block *origBlock,
447  const TypeConverter *converter)
448  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
449  origBlock(origBlock), converter(converter) {}
450 
451  static bool classof(const IRRewrite *rewrite) {
452  return rewrite->getKind() == Kind::BlockTypeConversion;
453  }
454 
455  Block *getOrigBlock() const { return origBlock; }
456 
457  const TypeConverter *getConverter() const { return converter; }
458 
459  void commit(RewriterBase &rewriter) override;
460 
461  void rollback() override;
462 
463 private:
464  /// The original block that was requested to have its signature converted.
465  Block *origBlock;
466 
467  /// The type converter used to convert the arguments.
468  const TypeConverter *converter;
469 };
470 
471 /// Replacing a block argument. This rewrite is not immediately reflected in the
472 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
473 /// until the rewrite is committed.
474 class ReplaceBlockArgRewrite : public BlockRewrite {
475 public:
476  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
477  Block *block, BlockArgument arg)
478  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
479 
480  static bool classof(const IRRewrite *rewrite) {
481  return rewrite->getKind() == Kind::ReplaceBlockArg;
482  }
483 
484  void commit(RewriterBase &rewriter) override;
485 
486  void rollback() override;
487 
488 private:
489  BlockArgument arg;
490 };
491 
492 /// An operation rewrite.
493 class OperationRewrite : public IRRewrite {
494 public:
495  /// Return the operation that this rewrite operates on.
496  Operation *getOperation() const { return op; }
497 
498  static bool classof(const IRRewrite *rewrite) {
499  return rewrite->getKind() >= Kind::MoveOperation &&
500  rewrite->getKind() <= Kind::UnresolvedMaterialization;
501  }
502 
503 protected:
504  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
505  Operation *op)
506  : IRRewrite(kind, rewriterImpl), op(op) {}
507 
508  // The operation that this rewrite operates on.
509  Operation *op;
510 };
511 
512 /// Moving of an operation. This rewrite is immediately reflected in the IR.
513 class MoveOperationRewrite : public OperationRewrite {
514 public:
515  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
516  Operation *op, Block *block, Operation *insertBeforeOp)
517  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
518  insertBeforeOp(insertBeforeOp) {}
519 
520  static bool classof(const IRRewrite *rewrite) {
521  return rewrite->getKind() == Kind::MoveOperation;
522  }
523 
524  void commit(RewriterBase &rewriter) override {
525  // The operation was already moved. Just inform the listener.
526  if (auto *listener = rewriter.getListener()) {
527  // Note: `previousIt` cannot be passed because this is a delayed
528  // notification and iterators into past IR state cannot be represented.
529  listener->notifyOperationInserted(
530  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
531  /*insertPt=*/{}));
532  }
533  }
534 
535  void rollback() override {
536  // Move the operation back to its original position.
537  Block::iterator before =
538  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
539  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
540  }
541 
542 private:
543  // The block in which this operation was previously contained.
544  Block *block;
545 
546  // The original successor of this operation before it was moved. "nullptr"
547  // if this operation was the only operation in the region.
548  Operation *insertBeforeOp;
549 };
550 
551 /// In-place modification of an op. This rewrite is immediately reflected in
552 /// the IR. The previous state of the operation is stored in this object.
553 class ModifyOperationRewrite : public OperationRewrite {
554 public:
555  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
556  Operation *op)
557  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
558  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
559  operands(op->operand_begin(), op->operand_end()),
560  successors(op->successor_begin(), op->successor_end()) {
561  if (OpaqueProperties prop = op->getPropertiesStorage()) {
562  // Make a copy of the properties.
563  propertiesStorage = operator new(op->getPropertiesStorageSize());
564  OpaqueProperties propCopy(propertiesStorage);
565  name.initOpProperties(propCopy, /*init=*/prop);
566  }
567  }
568 
569  static bool classof(const IRRewrite *rewrite) {
570  return rewrite->getKind() == Kind::ModifyOperation;
571  }
572 
573  ~ModifyOperationRewrite() override {
574  assert(!propertiesStorage &&
575  "rewrite was neither committed nor rolled back");
576  }
577 
578  void commit(RewriterBase &rewriter) override {
579  // Notify the listener that the operation was modified in-place.
580  if (auto *listener =
581  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
582  listener->notifyOperationModified(op);
583 
584  if (propertiesStorage) {
585  OpaqueProperties propCopy(propertiesStorage);
586  // Note: The operation may have been erased in the mean time, so
587  // OperationName must be stored in this object.
588  name.destroyOpProperties(propCopy);
589  operator delete(propertiesStorage);
590  propertiesStorage = nullptr;
591  }
592  }
593 
594  void rollback() override {
595  op->setLoc(loc);
596  op->setAttrs(attrs);
597  op->setOperands(operands);
598  for (const auto &it : llvm::enumerate(successors))
599  op->setSuccessor(it.value(), it.index());
600  if (propertiesStorage) {
601  OpaqueProperties propCopy(propertiesStorage);
602  op->copyProperties(propCopy);
603  name.destroyOpProperties(propCopy);
604  operator delete(propertiesStorage);
605  propertiesStorage = nullptr;
606  }
607  }
608 
609 private:
610  OperationName name;
611  LocationAttr loc;
612  DictionaryAttr attrs;
613  SmallVector<Value, 8> operands;
614  SmallVector<Block *, 2> successors;
615  void *propertiesStorage = nullptr;
616 };
617 
618 /// Replacing an operation. Erasing an operation is treated as a special case
619 /// with "null" replacements. This rewrite is not immediately reflected in the
620 /// IR. An internal IR mapping is updated, but values are not replaced and the
621 /// original op is not erased until the rewrite is committed.
622 class ReplaceOperationRewrite : public OperationRewrite {
623 public:
624  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
625  Operation *op, const TypeConverter *converter)
626  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
627  converter(converter) {}
628 
629  static bool classof(const IRRewrite *rewrite) {
630  return rewrite->getKind() == Kind::ReplaceOperation;
631  }
632 
633  void commit(RewriterBase &rewriter) override;
634 
635  void rollback() override;
636 
637  void cleanup(RewriterBase &rewriter) override;
638 
639  const TypeConverter *getConverter() const { return converter; }
640 
641 private:
642  /// An optional type converter that can be used to materialize conversions
643  /// between the new and old values if necessary.
644  const TypeConverter *converter;
645 };
646 
647 class CreateOperationRewrite : public OperationRewrite {
648 public:
649  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
650  Operation *op)
651  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
652 
653  static bool classof(const IRRewrite *rewrite) {
654  return rewrite->getKind() == Kind::CreateOperation;
655  }
656 
657  void commit(RewriterBase &rewriter) override {
658  // The operation was already created and inserted. Just inform the listener.
659  if (auto *listener = rewriter.getListener())
660  listener->notifyOperationInserted(op, /*previous=*/{});
661  }
662 
663  void rollback() override;
664 };
665 
/// Classifies a materialization by the direction of the type conversion it
/// bridges.
enum MaterializationKind {
  /// Converts an illegal block argument type to the original one.
  Argument = 0,

  /// Converts from an illegal type to a legal one.
  Target = 1,

  /// Converts from a legal type back to an illegal one.
  Source = 2
};
680 
681 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
682 /// op. Unresolved materializations are erased at the end of the dialect
683 /// conversion.
684 class UnresolvedMaterializationRewrite : public OperationRewrite {
685 public:
686  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
687  UnrealizedConversionCastOp op,
688  const TypeConverter *converter,
689  MaterializationKind kind, Type originalType);
690 
691  static bool classof(const IRRewrite *rewrite) {
692  return rewrite->getKind() == Kind::UnresolvedMaterialization;
693  }
694 
695  void rollback() override;
696 
697  UnrealizedConversionCastOp getOperation() const {
698  return cast<UnrealizedConversionCastOp>(op);
699  }
700 
701  /// Return the type converter of this materialization (which may be null).
702  const TypeConverter *getConverter() const {
703  return converterAndKind.getPointer();
704  }
705 
706  /// Return the kind of this materialization.
707  MaterializationKind getMaterializationKind() const {
708  return converterAndKind.getInt();
709  }
710 
711  /// Return the original type of the SSA value.
712  Type getOriginalType() const { return originalType; }
713 
714 private:
715  /// The corresponding type converter to use when resolving this
716  /// materialization, and the kind of this materialization.
717  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
718  converterAndKind;
719 
720  /// The original type of the SSA value. Only used for target
721  /// materializations.
722  Type originalType;
723 };
724 } // namespace
725 
726 /// Return "true" if there is an operation rewrite that matches the specified
727 /// rewrite type and operation among the given rewrites.
728 template <typename RewriteTy, typename R>
729 static bool hasRewrite(R &&rewrites, Operation *op) {
730  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
731  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
732  return rewriteTy && rewriteTy->getOperation() == op;
733  });
734 }
735 
736 //===----------------------------------------------------------------------===//
737 // ConversionPatternRewriterImpl
738 //===----------------------------------------------------------------------===//
739 namespace mlir {
740 namespace detail {
743  const ConversionConfig &config)
744  : context(ctx), eraseRewriter(ctx), config(config) {}
745 
746  //===--------------------------------------------------------------------===//
747  // State Management
748  //===--------------------------------------------------------------------===//
749 
750  /// Return the current state of the rewriter.
751  RewriterState getCurrentState();
752 
753  /// Apply all requested operation rewrites. This method is invoked when the
754  /// conversion process succeeds.
755  void applyRewrites();
756 
757  /// Reset the state of the rewriter to a previously saved point.
758  void resetState(RewriterState state);
759 
760  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
761  /// failure.
762  template <typename RewriteTy, typename... Args>
763  void appendRewrite(Args &&...args) {
764  rewrites.push_back(
765  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
766  }
767 
768  /// Undo the rewrites (motions, splits) one by one in reverse order until
769  /// "numRewritesToKeep" rewrites remains.
770  void undoRewrites(unsigned numRewritesToKeep = 0);
771 
772  /// Remap the given values to those with potentially different types. Returns
773  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
774  /// is the tag used when describing a value within a diagnostic, e.g.
775  /// "operand".
776  LogicalResult remapValues(StringRef valueDiagTag,
777  std::optional<Location> inputLoc,
778  PatternRewriter &rewriter, ValueRange values,
779  SmallVectorImpl<Value> &remapped);
780 
781  /// Return "true" if the given operation is ignored, and does not need to be
782  /// converted.
783  bool isOpIgnored(Operation *op) const;
784 
785  /// Return "true" if the given operation was replaced or erased.
786  bool wasOpReplaced(Operation *op) const;
787 
788  //===--------------------------------------------------------------------===//
789  // Type Conversion
790  //===--------------------------------------------------------------------===//
791 
792  /// Convert the types of block arguments within the given region.
793  FailureOr<Block *>
794  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
795  const TypeConverter &converter,
796  TypeConverter::SignatureConversion *entryConversion);
797 
798  /// Apply the given signature conversion on the given block. The new block
799  /// containing the updated signature is returned. If no conversions were
800  /// necessary, e.g. if the block has no arguments, `block` is returned.
801  /// `converter` is used to generate any necessary cast operations that
802  /// translate between the origin argument types and those specified in the
803  /// signature conversion.
804  Block *applySignatureConversion(
805  ConversionPatternRewriter &rewriter, Block *block,
806  const TypeConverter *converter,
807  TypeConverter::SignatureConversion &signatureConversion);
808 
809  //===--------------------------------------------------------------------===//
810  // Materializations
811  //===--------------------------------------------------------------------===//
812 
813  /// Build an unresolved materialization operation given an output type and set
814  /// of input operands.
815  Value buildUnresolvedMaterialization(MaterializationKind kind,
817  ValueRange inputs, Type outputType,
818  Type originalType,
819  const TypeConverter *converter);
820 
821  //===--------------------------------------------------------------------===//
822  // Rewriter Notification Hooks
823  //===--------------------------------------------------------------------===//
824 
825  //// Notifies that an op was inserted.
826  void notifyOperationInserted(Operation *op,
827  OpBuilder::InsertPoint previous) override;
828 
829  /// Notifies that an op is about to be replaced with the given values.
830  void notifyOpReplaced(Operation *op, ValueRange newValues);
831 
832  /// Notifies that a block is about to be erased.
833  void notifyBlockIsBeingErased(Block *block);
834 
835  /// Notifies that a block was inserted.
836  void notifyBlockInserted(Block *block, Region *previous,
837  Region::iterator previousIt) override;
838 
839  /// Notifies that a block is being inlined into another block.
840  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
841  Block::iterator before);
842 
843  /// Notifies that a pattern match failed for the given reason.
844  void
845  notifyMatchFailure(Location loc,
846  function_ref<void(Diagnostic &)> reasonCallback) override;
847 
848  //===--------------------------------------------------------------------===//
849  // IR Erasure
850  //===--------------------------------------------------------------------===//
851 
852  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
853  /// operation or block is erased multiple times. This rewriter assumes that
854  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
856  public:
858  : RewriterBase(context, /*listener=*/this) {}
859 
860  /// Erase the given op (unless it was already erased).
861  void eraseOp(Operation *op) override {
862  if (wasErased(op))
863  return;
864  op->dropAllUses();
866  }
867 
868  /// Erase the given block (unless it was already erased).
869  void eraseBlock(Block *block) override {
870  if (wasErased(block))
871  return;
872  assert(block->empty() && "expected empty block");
873  block->dropAllDefinedValueUses();
875  }
876 
877  bool wasErased(void *ptr) const { return erased.contains(ptr); }
878 
879  void notifyOperationErased(Operation *op) override { erased.insert(op); }
880 
881  void notifyBlockErased(Block *block) override { erased.insert(block); }
882 
883  private:
884  /// Pointers to all erased operations and blocks.
885  DenseSet<void *> erased;
886  };
887 
888  //===--------------------------------------------------------------------===//
889  // State
890  //===--------------------------------------------------------------------===//
891 
892  /// MLIR context.
894 
895  /// A rewriter that keeps track of ops/block that were already erased and
896  /// skips duplicate op/block erasures. This rewriter is used during the
897  /// "cleanup" phase.
899 
900  // Mapping between replaced values that differ in type. This happens when
901  // replacing a value with one of a different type.
902  ConversionValueMapping mapping;
903 
904  /// Ordered list of block operations (creations, splits, motions).
906 
907  /// A set of operations that should no longer be considered for legalization.
908  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
909  /// tracked separately.
911 
912  /// A set of operations that were replaced/erased. Such ops are not erased
913  /// immediately but only when the dialect conversion succeeds. In the mean
914  /// time, they should no longer be considered for legalization and any attempt
915  /// to modify/access them is invalid rewriter API usage.
917 
918  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
919  /// to the corresponding rewrite objects.
922 
923  /// The current type converter, or nullptr if no type converter is currently
924  /// active.
925  const TypeConverter *currentTypeConverter = nullptr;
926 
927  /// A mapping of regions to type converters that should be used when
928  /// converting the arguments of blocks within that region.
930 
931  /// Dialect conversion configuration.
933 
934 #ifndef NDEBUG
935  /// A set of operations that have pending updates. This tracking isn't
936  /// strictly necessary, and is thus only active during debug builds for extra
937  /// verification.
939 
940  /// A logger used to emit diagnostics during the conversion process.
941  llvm::ScopedPrinter logger{llvm::dbgs()};
942 #endif
943 };
944 } // namespace detail
945 } // namespace mlir
946 
947 const ConversionConfig &IRRewrite::getConfig() const {
948  return rewriterImpl.config;
949 }
950 
951 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
952  // Inform the listener about all IR modifications that have already taken
953  // place: References to the original block have been replaced with the new
954  // block.
955  if (auto *listener =
956  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
957  for (Operation *op : block->getUsers())
958  listener->notifyOperationModified(op);
959 }
960 
961 void BlockTypeConversionRewrite::rollback() {
962  block->replaceAllUsesWith(origBlock);
963 }
964 
965 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
966  Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
967  if (!repl)
968  return;
969 
970  if (isa<BlockArgument>(repl)) {
971  rewriter.replaceAllUsesWith(arg, repl);
972  return;
973  }
974 
975  // If the replacement value is an operation, we check to make sure that we
976  // don't replace uses that are within the parent operation of the
977  // replacement value.
978  Operation *replOp = cast<OpResult>(repl).getOwner();
979  Block *replBlock = replOp->getBlock();
980  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
981  Operation *user = operand.getOwner();
982  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
983  });
984 }
985 
986 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
987 
988 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
989  auto *listener =
990  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
991 
992  // Compute replacement values.
993  SmallVector<Value> replacements =
994  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
995  return rewriterImpl.mapping.lookupOrNull(result, result.getType());
996  });
997 
998  // Notify the listener that the operation is about to be replaced.
999  if (listener)
1000  listener->notifyOperationReplaced(op, replacements);
1001 
1002  // Replace all uses with the new values.
1003  for (auto [result, newValue] :
1004  llvm::zip_equal(op->getResults(), replacements))
1005  if (newValue)
1006  rewriter.replaceAllUsesWith(result, newValue);
1007 
1008  // The original op will be erased, so remove it from the set of unlegalized
1009  // ops.
1010  if (getConfig().unlegalizedOps)
1011  getConfig().unlegalizedOps->erase(op);
1012 
1013  // Notify the listener that the operation (and its nested operations) was
1014  // erased.
1015  if (listener) {
1017  [&](Operation *op) { listener->notifyOperationErased(op); });
1018  }
1019 
1020  // Do not erase the operation yet. It may still be referenced in `mapping`.
1021  // Just unlink it for now and erase it during cleanup.
1022  op->getBlock()->getOperations().remove(op);
1023 }
1024 
1025 void ReplaceOperationRewrite::rollback() {
1026  for (auto result : op->getResults())
1027  rewriterImpl.mapping.erase(result);
1028 }
1029 
1030 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1031  rewriter.eraseOp(op);
1032 }
1033 
1034 void CreateOperationRewrite::rollback() {
1035  for (Region &region : op->getRegions()) {
1036  while (!region.getBlocks().empty())
1037  region.getBlocks().remove(region.getBlocks().begin());
1038  }
1039  op->dropAllUses();
1040  op->erase();
1041 }
1042 
/// Records the creation of an unresolved materialization (an
/// `UnrealizedConversionCastOp`) together with the type converter and
/// materialization kind it was created for. `originalType` may only be set for
/// target materializations.
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
    ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
    const TypeConverter *converter, MaterializationKind kind, Type originalType)
    : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
      converterAndKind(converter, kind), originalType(originalType) {
  assert((!originalType || kind == MaterializationKind::Target) &&
         "original type is valid only for target materializations");
  // Register this rewrite so that the cast op can be looked up to find it.
  rewriterImpl.unresolvedMaterializations[op] = this;
}
1052 
1053 void UnresolvedMaterializationRewrite::rollback() {
1054  if (getMaterializationKind() == MaterializationKind::Target) {
1055  for (Value input : op->getOperands())
1056  rewriterImpl.mapping.erase(input);
1057  }
1058  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1059  op->erase();
1060 }
1061 
1063  // Commit all rewrites.
1064  IRRewriter rewriter(context, config.listener);
1065  for (auto &rewrite : rewrites)
1066  rewrite->commit(rewriter);
1067 
1068  // Clean up all rewrites.
1069  for (auto &rewrite : rewrites)
1070  rewrite->cleanup(eraseRewriter);
1071 }
1072 
1073 //===----------------------------------------------------------------------===//
1074 // State Management
1075 
1077  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1078 }
1079 
1081  // Undo any rewrites.
1082  undoRewrites(state.numRewrites);
1083 
1084  // Pop all of the recorded ignored operations that are no longer valid.
1085  while (ignoredOps.size() != state.numIgnoredOperations)
1086  ignoredOps.pop_back();
1087 
1088  while (replacedOps.size() != state.numReplacedOps)
1089  replacedOps.pop_back();
1090 }
1091 
1092 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1093  for (auto &rewrite :
1094  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1095  rewrite->rollback();
1096  rewrites.resize(numRewritesToKeep);
1097 }
1098 
1100  StringRef valueDiagTag, std::optional<Location> inputLoc,
1101  PatternRewriter &rewriter, ValueRange values,
1102  SmallVectorImpl<Value> &remapped) {
1103  remapped.reserve(llvm::size(values));
1104 
1105  for (const auto &it : llvm::enumerate(values)) {
1106  Value operand = it.value();
1107  Type origType = operand.getType();
1108  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1109 
1110  if (!currentTypeConverter) {
1111  // The current pattern does not have a type converter. I.e., it does not
1112  // distinguish between legal and illegal types. For each operand, simply
1113  // pass through the most recently mapped value.
1114  remapped.push_back(mapping.lookupOrDefault(operand));
1115  continue;
1116  }
1117 
1118  // If there is no legal conversion, fail to match this pattern.
1119  SmallVector<Type, 1> legalTypes;
1120  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1121  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1122  diag << "unable to convert type for " << valueDiagTag << " #"
1123  << it.index() << ", type was " << origType;
1124  });
1125  return failure();
1126  }
1127 
1128  if (legalTypes.size() != 1) {
1129  // TODO: Parts of the dialect conversion infrastructure do not support
1130  // 1->N type conversions yet. Therefore, if a type is converted to 0 or
1131  // multiple types, the only thing that we can do for now is passing
1132  // through the most recently mapped value. Fixing this requires
1133  // improvements to the `ConversionValueMapping` (to be able to store 1:N
1134  // mappings) and to the `ConversionPattern` adaptor handling (to be able
1135  // to pass multiple remapped values for a single operand to the adaptor).
1136  remapped.push_back(mapping.lookupOrDefault(operand));
1137  continue;
1138  }
1139 
1140  // Handle 1->1 type conversions.
1141  Type desiredType = legalTypes.front();
1142  // Try to find a mapped value with the desired type. (Or the operand itself
1143  // if the value is not mapped at all.)
1144  Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1145  if (newOperand.getType() != desiredType) {
1146  // If the looked up value's type does not have the desired type, it means
1147  // that the value was replaced with a value of different type and no
1148  // source materialization was created yet.
1150  MaterializationKind::Target, computeInsertPoint(newOperand),
1151  operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
1152  /*originalType=*/origType, currentTypeConverter);
1153  mapping.map(newOperand, castValue);
1154  newOperand = castValue;
1155  }
1156  remapped.push_back(newOperand);
1157  }
1158  return success();
1159 }
1160 
1162  // Check to see if this operation is ignored or was replaced.
1163  return replacedOps.count(op) || ignoredOps.count(op);
1164 }
1165 
1167  // Check to see if this operation was replaced.
1168  return replacedOps.count(op);
1169 }
1170 
1171 //===----------------------------------------------------------------------===//
1172 // Type Conversion
1173 
1175  ConversionPatternRewriter &rewriter, Region *region,
1176  const TypeConverter &converter,
1177  TypeConverter::SignatureConversion *entryConversion) {
1178  regionToConverter[region] = &converter;
1179  if (region->empty())
1180  return nullptr;
1181 
1182  // Convert the arguments of each non-entry block within the region.
1183  for (Block &block :
1184  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1185  // Compute the signature for the block with the provided converter.
1186  std::optional<TypeConverter::SignatureConversion> conversion =
1187  converter.convertBlockSignature(&block);
1188  if (!conversion)
1189  return failure();
1190  // Convert the block with the computed signature.
1191  applySignatureConversion(rewriter, &block, &converter, *conversion);
1192  }
1193 
1194  // Convert the entry block. If an entry signature conversion was provided,
1195  // use that one. Otherwise, compute the signature with the type converter.
1196  if (entryConversion)
1197  return applySignatureConversion(rewriter, &region->front(), &converter,
1198  *entryConversion);
1199  std::optional<TypeConverter::SignatureConversion> conversion =
1200  converter.convertBlockSignature(&region->front());
1201  if (!conversion)
1202  return failure();
1203  return applySignatureConversion(rewriter, &region->front(), &converter,
1204  *conversion);
1205 }
1206 
1208  ConversionPatternRewriter &rewriter, Block *block,
1209  const TypeConverter *converter,
1210  TypeConverter::SignatureConversion &signatureConversion) {
1211  OpBuilder::InsertionGuard g(rewriter);
1212 
1213  // If no arguments are being changed or added, there is nothing to do.
1214  unsigned origArgCount = block->getNumArguments();
1215  auto convertedTypes = signatureConversion.getConvertedTypes();
1216  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1217  return block;
1218 
1219  // Compute the locations of all block arguments in the new block.
1220  SmallVector<Location> newLocs(convertedTypes.size(),
1221  rewriter.getUnknownLoc());
1222  for (unsigned i = 0; i < origArgCount; ++i) {
1223  auto inputMap = signatureConversion.getInputMapping(i);
1224  if (!inputMap || inputMap->replacementValue)
1225  continue;
1226  Location origLoc = block->getArgument(i).getLoc();
1227  for (unsigned j = 0; j < inputMap->size; ++j)
1228  newLocs[inputMap->inputNo + j] = origLoc;
1229  }
1230 
1231  // Insert a new block with the converted block argument types and move all ops
1232  // from the old block to the new block.
1233  Block *newBlock =
1234  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1235  convertedTypes, newLocs);
1236 
1237  // If a listener is attached to the dialect conversion, ops cannot be moved
1238  // to the destination block in bulk ("fast path"). This is because at the time
1239  // the notifications are sent, it is unknown which ops were moved. Instead,
1240  // ops should be moved one-by-one ("slow path"), so that a separate
1241  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1242  // a bit more efficient, so we try to do that when possible.
1243  bool fastPath = !config.listener;
1244  if (fastPath) {
1245  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1246  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1247  } else {
1248  while (!block->empty())
1249  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1250  }
1251 
1252  // Replace all uses of the old block with the new block.
1253  block->replaceAllUsesWith(newBlock);
1254 
1255  for (unsigned i = 0; i != origArgCount; ++i) {
1256  BlockArgument origArg = block->getArgument(i);
1257  Type origArgType = origArg.getType();
1258 
1259  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1260  signatureConversion.getInputMapping(i);
1261  if (!inputMap) {
1262  // This block argument was dropped and no replacement value was provided.
1263  // Materialize a replacement value "out of thin air".
1265  MaterializationKind::Source,
1266  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1267  /*inputs=*/ValueRange(),
1268  /*outputType=*/origArgType, /*originalType=*/Type(), converter);
1269  mapping.map(origArg, repl);
1270  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1271  continue;
1272  }
1273 
1274  if (Value repl = inputMap->replacementValue) {
1275  // This block argument was dropped and a replacement value was provided.
1276  assert(inputMap->size == 0 &&
1277  "invalid to provide a replacement value when the argument isn't "
1278  "dropped");
1279  mapping.map(origArg, repl);
1280  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1281  continue;
1282  }
1283 
1284  // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
1285  // dialect conversion. Therefore, we need an argument materialization to
1286  // turn the replacement block arguments into a single SSA value that can be
1287  // used as a replacement.
1288  auto replArgs =
1289  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1292  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1293  /*inputs=*/replArgs, /*outputType=*/origArgType,
1294  /*originalType=*/Type(), converter);
1295  mapping.map(origArg, argMat);
1296 
1297  Type legalOutputType;
1298  if (converter) {
1299  legalOutputType = converter->convertType(origArgType);
1300  } else if (replArgs.size() == 1) {
1301  // When there is no type converter, assume that the new block argument
1302  // types are legal. This is reasonable to assume because they were
1303  // specified by the user.
1304  // FIXME: This won't work for 1->N conversions because multiple output
1305  // types are not supported in parts of the dialect conversion. In such a
1306  // case, we currently use the original block argument type (produced by
1307  // the argument materialization).
1308  legalOutputType = replArgs[0].getType();
1309  }
1310  if (legalOutputType && legalOutputType != origArgType) {
1312  MaterializationKind::Target, computeInsertPoint(argMat),
1313  origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
1314  /*originalType=*/origArgType, converter);
1315  mapping.map(argMat, targetMat);
1316  }
1317  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1318  }
1319 
1320  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
1321 
1322  // Erase the old block. (It is just unlinked for now and will be erased during
1323  // cleanup.)
1324  rewriter.eraseBlock(block);
1325 
1326  return newBlock;
1327 }
1328 
1329 //===----------------------------------------------------------------------===//
1330 // Materializations
1331 //===----------------------------------------------------------------------===//
1332 
1333 /// Build an unresolved materialization operation given an output type and set
1334 /// of input operands.
1336  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1337  ValueRange inputs, Type outputType, Type originalType,
1338  const TypeConverter *converter) {
1339  assert((!originalType || kind == MaterializationKind::Target) &&
1340  "original type is valid only for target materializations");
1341 
1342  // Avoid materializing an unnecessary cast.
1343  if (inputs.size() == 1 && inputs.front().getType() == outputType)
1344  return inputs.front();
1345 
1346  // Create an unresolved materialization. We use a new OpBuilder to avoid
1347  // tracking the materialization like we do for other operations.
1348  OpBuilder builder(outputType.getContext());
1349  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1350  auto convertOp =
1351  builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1352  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1353  originalType);
1354  return convertOp.getResult(0);
1355 }
1356 
1357 //===----------------------------------------------------------------------===//
1358 // Rewriter Notification Hooks
1359 
1361  Operation *op, OpBuilder::InsertPoint previous) {
1362  LLVM_DEBUG({
1363  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1364  << ")\n";
1365  });
1366  assert(!wasOpReplaced(op->getParentOp()) &&
1367  "attempting to insert into a block within a replaced/erased op");
1368 
1369  if (!previous.isSet()) {
1370  // This is a newly created op.
1371  appendRewrite<CreateOperationRewrite>(op);
1372  return;
1373  }
1374  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1375  ? nullptr
1376  : &*previous.getPoint();
1377  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1378 }
1379 
1381  ValueRange newValues) {
1382  assert(newValues.size() == op->getNumResults());
1383  assert(!ignoredOps.contains(op) && "operation was already replaced");
1384 
1385  // Create mappings for each of the new result values.
1386  for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
1387  if (!newValue) {
1388  // This result was dropped and no replacement value was provided.
1389  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1390  if (unresolvedMaterializations.contains(castOp)) {
1391  // Do not create another materializations if we are erasing a
1392  // materialization.
1393  continue;
1394  }
1395  }
1396 
1397  // Materialize a replacement value "out of thin air".
1398  newValue = buildUnresolvedMaterialization(
1399  MaterializationKind::Source, computeInsertPoint(result),
1400  result.getLoc(), /*inputs=*/ValueRange(),
1401  /*outputType=*/result.getType(), /*originalType=*/Type(),
1403  }
1404 
1405  // Remap, and check for any result type changes.
1406  mapping.map(result, newValue);
1407  }
1408 
1409  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1410 
1411  // Mark this operation and all nested ops as replaced.
1412  op->walk([&](Operation *op) { replacedOps.insert(op); });
1413 }
1414 
1416  appendRewrite<EraseBlockRewrite>(block);
1417 }
1418 
1420  Block *block, Region *previous, Region::iterator previousIt) {
1421  assert(!wasOpReplaced(block->getParentOp()) &&
1422  "attempting to insert into a region within a replaced/erased op");
1423  LLVM_DEBUG(
1424  {
1425  Operation *parent = block->getParentOp();
1426  if (parent) {
1427  logger.startLine() << "** Insert Block into : '" << parent->getName()
1428  << "'(" << parent << ")\n";
1429  } else {
1430  logger.startLine()
1431  << "** Insert Block into detached Region (nullptr parent op)'";
1432  }
1433  });
1434 
1435  if (!previous) {
1436  // This is a newly created block.
1437  appendRewrite<CreateBlockRewrite>(block);
1438  return;
1439  }
1440  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1441  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1442 }
1443 
1445  Block *block, Block *srcBlock, Block::iterator before) {
1446  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1447 }
1448 
1450  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1451  LLVM_DEBUG({
1453  reasonCallback(diag);
1454  logger.startLine() << "** Failure : " << diag.str() << "\n";
1455  if (config.notifyCallback)
1457  });
1458 }
1459 
1460 //===----------------------------------------------------------------------===//
1461 // ConversionPatternRewriter
1462 //===----------------------------------------------------------------------===//
1463 
/// Creates a conversion rewriter for `ctx` using the given conversion
/// `config`. A fresh implementation object is created and installed as this
/// rewriter's listener, so that IR notifications reach the conversion's
/// bookkeeping.
ConversionPatternRewriter::ConversionPatternRewriter(
    MLIRContext *ctx, const ConversionConfig &config)
    : PatternRewriter(ctx),
      impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
  setListener(impl.get());
}
1470 
1472 
1474  assert(op && newOp && "expected non-null op");
1475  replaceOp(op, newOp->getResults());
1476 }
1477 
1479  assert(op->getNumResults() == newValues.size() &&
1480  "incorrect # of replacement values");
1481  LLVM_DEBUG({
1482  impl->logger.startLine()
1483  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1484  });
1485  impl->notifyOpReplaced(op, newValues);
1486 }
1487 
1489  LLVM_DEBUG({
1490  impl->logger.startLine()
1491  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1492  });
1493  SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1494  impl->notifyOpReplaced(op, nullRepls);
1495 }
1496 
1498  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1499  "attempting to erase a block within a replaced/erased op");
1500 
1501  // Mark all ops for erasure.
1502  for (Operation &op : *block)
1503  eraseOp(&op);
1504 
1505  // Unlink the block from its parent region. The block is kept in the rewrite
1506  // object and will be actually destroyed when rewrites are applied. This
1507  // allows us to keep the operations in the block live and undo the removal by
1508  // re-inserting the block.
1509  impl->notifyBlockIsBeingErased(block);
1510  block->getParent()->getBlocks().remove(block);
1511 }
1512 
1514  Block *block, TypeConverter::SignatureConversion &conversion,
1515  const TypeConverter *converter) {
1516  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1517  "attempting to apply a signature conversion to a block within a "
1518  "replaced/erased op");
1519  return impl->applySignatureConversion(*this, block, converter, conversion);
1520 }
1521 
1523  Region *region, const TypeConverter &converter,
1524  TypeConverter::SignatureConversion *entryConversion) {
1525  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1526  "attempting to apply a signature conversion to a block within a "
1527  "replaced/erased op");
1528  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1529 }
1530 
1532  Value to) {
1533  LLVM_DEBUG({
1534  Operation *parentOp = from.getOwner()->getParentOp();
1535  impl->logger.startLine() << "** Replace Argument : '" << from
1536  << "'(in region of '" << parentOp->getName()
1537  << "'(" << from.getOwner()->getParentOp() << ")\n";
1538  });
1539  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
1540  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1541 }
1542 
1544  SmallVector<Value> remappedValues;
1545  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1546  remappedValues)))
1547  return nullptr;
1548  return remappedValues.front();
1549 }
1550 
1551 LogicalResult
1553  SmallVectorImpl<Value> &results) {
1554  if (keys.empty())
1555  return success();
1556  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1557  results);
1558 }
1559 
1561  Block::iterator before,
1562  ValueRange argValues) {
1563 #ifndef NDEBUG
1564  assert(argValues.size() == source->getNumArguments() &&
1565  "incorrect # of argument replacement values");
1566  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1567  "attempting to inline a block from a replaced/erased op");
1568  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1569  "attempting to inline a block into a replaced/erased op");
1570  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1571  // The source block will be deleted, so it should not have any users (i.e.,
1572  // there should be no predecessors).
1573  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1574  "expected 'source' to have no predecessors");
1575 #endif // NDEBUG
1576 
1577  // If a listener is attached to the dialect conversion, ops cannot be moved
1578  // to the destination block in bulk ("fast path"). This is because at the time
1579  // the notifications are sent, it is unknown which ops were moved. Instead,
1580  // ops should be moved one-by-one ("slow path"), so that a separate
1581  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1582  // a bit more efficient, so we try to do that when possible.
1583  bool fastPath = !impl->config.listener;
1584 
1585  if (fastPath)
1586  impl->notifyBlockBeingInlined(dest, source, before);
1587 
1588  // Replace all uses of block arguments.
1589  for (auto it : llvm::zip(source->getArguments(), argValues))
1590  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1591 
1592  if (fastPath) {
1593  // Move all ops at once.
1594  dest->getOperations().splice(before, source->getOperations());
1595  } else {
1596  // Move op by op.
1597  while (!source->empty())
1598  moveOpBefore(&source->front(), dest, before);
1599  }
1600 
1601  // Erase the source block.
1602  eraseBlock(source);
1603 }
1604 
1606  assert(!impl->wasOpReplaced(op) &&
1607  "attempting to modify a replaced/erased op");
1608 #ifndef NDEBUG
1609  impl->pendingRootUpdates.insert(op);
1610 #endif
1611  impl->appendRewrite<ModifyOperationRewrite>(op);
1612 }
1613 
1615  assert(!impl->wasOpReplaced(op) &&
1616  "attempting to modify a replaced/erased op");
1618  // There is nothing to do here, we only need to track the operation at the
1619  // start of the update.
1620 #ifndef NDEBUG
1621  assert(impl->pendingRootUpdates.erase(op) &&
1622  "operation did not have a pending in-place update");
1623 #endif
1624 }
1625 
1627 #ifndef NDEBUG
1628  assert(impl->pendingRootUpdates.erase(op) &&
1629  "operation did not have a pending in-place update");
1630 #endif
1631  // Erase the last update for this operation.
1632  auto it = llvm::find_if(
1633  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1634  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1635  return modifyRewrite && modifyRewrite->getOperation() == op;
1636  });
1637  assert(it != impl->rewrites.rend() && "no root update started on op");
1638  (*it)->rollback();
1639  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1640  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1641 }
1642 
1644  return *impl;
1645 }
1646 
1647 //===----------------------------------------------------------------------===//
1648 // ConversionPattern
1649 //===----------------------------------------------------------------------===//
1650 
1651 LogicalResult
1653  PatternRewriter &rewriter) const {
1654  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1655  auto &rewriterImpl = dialectRewriter.getImpl();
1656 
1657  // Track the current conversion pattern type converter in the rewriter.
1658  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1659  getTypeConverter());
1660 
1661  // Remap the operands of the operation.
1662  SmallVector<Value, 4> operands;
1663  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1664  op->getOperands(), operands))) {
1665  return failure();
1666  }
1667  return matchAndRewrite(op, operands, dialectRewriter);
1668 }
1669 
1670 //===----------------------------------------------------------------------===//
1671 // OperationLegalizer
1672 //===----------------------------------------------------------------------===//
1673 
1674 namespace {
1675 /// A set of rewrite patterns that can be used to legalize a given operation.
1676 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1677 
1678 /// This class defines a recursive operation legalizer.
1679 class OperationLegalizer {
1680 public:
1681  using LegalizationAction = ConversionTarget::LegalizationAction;
1682 
1683  OperationLegalizer(const ConversionTarget &targetInfo,
1684  const FrozenRewritePatternSet &patterns,
1685  const ConversionConfig &config);
1686 
1687  /// Returns true if the given operation is known to be illegal on the target.
1688  bool isIllegal(Operation *op) const;
1689 
1690  /// Attempt to legalize the given operation. Returns success if the operation
1691  /// was legalized, failure otherwise.
1692  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1693 
1694  /// Returns the conversion target in use by the legalizer.
1695  const ConversionTarget &getTarget() { return target; }
1696 
1697 private:
1698  /// Attempt to legalize the given operation by folding it.
1699  LogicalResult legalizeWithFold(Operation *op,
1700  ConversionPatternRewriter &rewriter);
1701 
1702  /// Attempt to legalize the given operation by applying a pattern. Returns
1703  /// success if the operation was legalized, failure otherwise.
1704  LogicalResult legalizeWithPattern(Operation *op,
1705  ConversionPatternRewriter &rewriter);
1706 
1707  /// Return true if the given pattern may be applied to the given operation,
1708  /// false otherwise.
1709  bool canApplyPattern(Operation *op, const Pattern &pattern,
1710  ConversionPatternRewriter &rewriter);
1711 
1712  /// Legalize the resultant IR after successfully applying the given pattern.
1713  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1714  ConversionPatternRewriter &rewriter,
1715  RewriterState &curState);
1716 
1717  /// Legalizes the actions registered during the execution of a pattern.
1718  LogicalResult
1719  legalizePatternBlockRewrites(Operation *op,
1720  ConversionPatternRewriter &rewriter,
1722  RewriterState &state, RewriterState &newState);
1723  LogicalResult legalizePatternCreatedOperations(
1725  RewriterState &state, RewriterState &newState);
1726  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1728  RewriterState &state,
1729  RewriterState &newState);
1730 
1731  //===--------------------------------------------------------------------===//
1732  // Cost Model
1733  //===--------------------------------------------------------------------===//
1734 
1735  /// Build an optimistic legalization graph given the provided patterns. This
1736  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1737  /// patterns for operations that are not directly legal, but may be
1738  /// transitively legal for the current target given the provided patterns.
1739  void buildLegalizationGraph(
1740  LegalizationPatterns &anyOpLegalizerPatterns,
1742 
1743  /// Compute the benefit of each node within the computed legalization graph.
1744  /// This orders the patterns within 'legalizerPatterns' based upon two
1745  /// criteria:
1746  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1747  /// represent the more direct mapping to the target.
1748  /// 2) When comparing patterns with the same legalization depth, prefer the
1749  /// pattern with the highest PatternBenefit. This allows for users to
1750  /// prefer specific legalizations over others.
1751  void computeLegalizationGraphBenefit(
1752  LegalizationPatterns &anyOpLegalizerPatterns,
1754 
1755  /// Compute the legalization depth when legalizing an operation of the given
1756  /// type.
1757  unsigned computeOpLegalizationDepth(
1758  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1760 
1761  /// Apply the conversion cost model to the given set of patterns, and return
1762  /// the smallest legalization depth of any of the patterns. See
1763  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1764  unsigned applyCostModelToPatterns(
1765  LegalizationPatterns &patterns,
1766  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1768 
1769  /// The current set of patterns that have been applied.
1770  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1771 
1772  /// The legalization information provided by the target.
1773  const ConversionTarget &target;
1774 
1775  /// The pattern applicator to use for conversions.
1776  PatternApplicator applicator;
1777 
1778  /// Dialect conversion configuration.
1779  const ConversionConfig &config;
1780 };
1781 } // namespace
1782 
1783 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1784  const FrozenRewritePatternSet &patterns,
1785  const ConversionConfig &config)
1786  : target(targetInfo), applicator(patterns), config(config) {
1787  // The set of patterns that can be applied to illegal operations to transform
1788  // them into legal ones.
1790  LegalizationPatterns anyOpLegalizerPatterns;
1791 
1792  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1793  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1794 }
1795 
1796 bool OperationLegalizer::isIllegal(Operation *op) const {
1797  return target.isIllegal(op);
1798 }
1799 
1800 LogicalResult
1801 OperationLegalizer::legalize(Operation *op,
1802  ConversionPatternRewriter &rewriter) {
1803 #ifndef NDEBUG
1804  const char *logLineComment =
1805  "//===-------------------------------------------===//\n";
1806 
1807  auto &logger = rewriter.getImpl().logger;
1808 #endif
1809  LLVM_DEBUG({
1810  logger.getOStream() << "\n";
1811  logger.startLine() << logLineComment;
1812  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1813  << op << ") {\n";
1814  logger.indent();
1815 
1816  // If the operation has no regions, just print it here.
1817  if (op->getNumRegions() == 0) {
1818  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1819  logger.getOStream() << "\n\n";
1820  }
1821  });
1822 
1823  // Check if this operation is legal on the target.
1824  if (auto legalityInfo = target.isLegal(op)) {
1825  LLVM_DEBUG({
1826  logSuccess(
1827  logger, "operation marked legal by the target{0}",
1828  legalityInfo->isRecursivelyLegal
1829  ? "; NOTE: operation is recursively legal; skipping internals"
1830  : "");
1831  logger.startLine() << logLineComment;
1832  });
1833 
1834  // If this operation is recursively legal, mark its children as ignored so
1835  // that we don't consider them for legalization.
1836  if (legalityInfo->isRecursivelyLegal) {
1837  op->walk([&](Operation *nested) {
1838  if (op != nested)
1839  rewriter.getImpl().ignoredOps.insert(nested);
1840  });
1841  }
1842 
1843  return success();
1844  }
1845 
1846  // Check to see if the operation is ignored and doesn't need to be converted.
1847  if (rewriter.getImpl().isOpIgnored(op)) {
1848  LLVM_DEBUG({
1849  logSuccess(logger, "operation marked 'ignored' during conversion");
1850  logger.startLine() << logLineComment;
1851  });
1852  return success();
1853  }
1854 
1855  // If the operation isn't legal, try to fold it in-place.
1856  // TODO: Should we always try to do this, even if the op is
1857  // already legal?
1858  if (succeeded(legalizeWithFold(op, rewriter))) {
1859  LLVM_DEBUG({
1860  logSuccess(logger, "operation was folded");
1861  logger.startLine() << logLineComment;
1862  });
1863  return success();
1864  }
1865 
1866  // Otherwise, we need to apply a legalization pattern to this operation.
1867  if (succeeded(legalizeWithPattern(op, rewriter))) {
1868  LLVM_DEBUG({
1869  logSuccess(logger, "");
1870  logger.startLine() << logLineComment;
1871  });
1872  return success();
1873  }
1874 
1875  LLVM_DEBUG({
1876  logFailure(logger, "no matched legalization pattern");
1877  logger.startLine() << logLineComment;
1878  });
1879  return failure();
1880 }
1881 
1882 LogicalResult
1883 OperationLegalizer::legalizeWithFold(Operation *op,
1884  ConversionPatternRewriter &rewriter) {
1885  auto &rewriterImpl = rewriter.getImpl();
1886  RewriterState curState = rewriterImpl.getCurrentState();
1887 
1888  LLVM_DEBUG({
1889  rewriterImpl.logger.startLine() << "* Fold {\n";
1890  rewriterImpl.logger.indent();
1891  });
1892 
1893  // Try to fold the operation.
1894  SmallVector<Value, 2> replacementValues;
1895  rewriter.setInsertionPoint(op);
1896  if (failed(rewriter.tryFold(op, replacementValues))) {
1897  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1898  return failure();
1899  }
1900  // An empty list of replacement values indicates that the fold was in-place.
1901  // As the operation changed, a new legalization needs to be attempted.
1902  if (replacementValues.empty())
1903  return legalize(op, rewriter);
1904 
1905  // Insert a replacement for 'op' with the folded replacement values.
1906  rewriter.replaceOp(op, replacementValues);
1907 
1908  // Recursively legalize any new constant operations.
1909  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
1910  i != e; ++i) {
1911  auto *createOp =
1912  dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
1913  if (!createOp)
1914  continue;
1915  if (failed(legalize(createOp->getOperation(), rewriter))) {
1916  LLVM_DEBUG(logFailure(rewriterImpl.logger,
1917  "failed to legalize generated constant '{0}'",
1918  createOp->getOperation()->getName()));
1919  rewriterImpl.resetState(curState);
1920  return failure();
1921  }
1922  }
1923 
1924  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1925  return success();
1926 }
1927 
1928 LogicalResult
1929 OperationLegalizer::legalizeWithPattern(Operation *op,
1930  ConversionPatternRewriter &rewriter) {
1931  auto &rewriterImpl = rewriter.getImpl();
1932 
1933  // Functor that returns if the given pattern may be applied.
1934  auto canApply = [&](const Pattern &pattern) {
1935  bool canApply = canApplyPattern(op, pattern, rewriter);
1936  if (canApply && config.listener)
1937  config.listener->notifyPatternBegin(pattern, op);
1938  return canApply;
1939  };
1940 
1941  // Functor that cleans up the rewriter state after a pattern failed to match.
1942  RewriterState curState = rewriterImpl.getCurrentState();
1943  auto onFailure = [&](const Pattern &pattern) {
1944  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
1945  LLVM_DEBUG({
1946  logFailure(rewriterImpl.logger, "pattern failed to match");
1947  if (rewriterImpl.config.notifyCallback) {
1949  diag << "Failed to apply pattern \"" << pattern.getDebugName()
1950  << "\" on op:\n"
1951  << *op;
1952  rewriterImpl.config.notifyCallback(diag);
1953  }
1954  });
1955  if (config.listener)
1956  config.listener->notifyPatternEnd(pattern, failure());
1957  rewriterImpl.resetState(curState);
1958  appliedPatterns.erase(&pattern);
1959  };
1960 
1961  // Functor that performs additional legalization when a pattern is
1962  // successfully applied.
1963  auto onSuccess = [&](const Pattern &pattern) {
1964  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
1965  auto result = legalizePatternResult(op, pattern, rewriter, curState);
1966  appliedPatterns.erase(&pattern);
1967  if (failed(result))
1968  rewriterImpl.resetState(curState);
1969  if (config.listener)
1970  config.listener->notifyPatternEnd(pattern, result);
1971  return result;
1972  };
1973 
1974  // Try to match and rewrite a pattern on this operation.
1975  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1976  onSuccess);
1977 }
1978 
1979 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1980  ConversionPatternRewriter &rewriter) {
1981  LLVM_DEBUG({
1982  auto &os = rewriter.getImpl().logger;
1983  os.getOStream() << "\n";
1984  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1985  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
1986  os.getOStream() << ")' {\n";
1987  os.indent();
1988  });
1989 
1990  // Ensure that we don't cycle by not allowing the same pattern to be
1991  // applied twice in the same recursion stack if it is not known to be safe.
1992  if (!pattern.hasBoundedRewriteRecursion() &&
1993  !appliedPatterns.insert(&pattern).second) {
1994  LLVM_DEBUG(
1995  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
1996  return false;
1997  }
1998  return true;
1999 }
2000 
2001 LogicalResult
2002 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2003  ConversionPatternRewriter &rewriter,
2004  RewriterState &curState) {
2005  auto &impl = rewriter.getImpl();
2006 
2007 #ifndef NDEBUG
2008  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2009  // Check that the root was either replaced or updated in place.
2010  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2011  auto replacedRoot = [&] {
2012  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2013  };
2014  auto updatedRootInPlace = [&] {
2015  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2016  };
2017  assert((replacedRoot() || updatedRootInPlace()) &&
2018  "expected pattern to replace the root operation");
2019 #endif // NDEBUG
2020 
2021  // Legalize each of the actions registered during application.
2022  RewriterState newState = impl.getCurrentState();
2023  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2024  newState)) ||
2025  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2026  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2027  newState))) {
2028  return failure();
2029  }
2030 
2031  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2032  return success();
2033 }
2034 
2035 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2036  Operation *op, ConversionPatternRewriter &rewriter,
2037  ConversionPatternRewriterImpl &impl, RewriterState &state,
2038  RewriterState &newState) {
2039  SmallPtrSet<Operation *, 16> operationsToIgnore;
2040 
2041  // If the pattern moved or created any blocks, make sure the types of block
2042  // arguments get legalized.
2043  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2044  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2045  if (!rewrite)
2046  continue;
2047  Block *block = rewrite->getBlock();
2048  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2049  ReplaceBlockArgRewrite>(rewrite))
2050  continue;
2051  // Only check blocks outside of the current operation.
2052  Operation *parentOp = block->getParentOp();
2053  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2054  continue;
2055 
2056  // If the region of the block has a type converter, try to convert the block
2057  // directly.
2058  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2059  std::optional<TypeConverter::SignatureConversion> conversion =
2060  converter->convertBlockSignature(block);
2061  if (!conversion) {
2062  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2063  "block"));
2064  return failure();
2065  }
2066  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2067  continue;
2068  }
2069 
2070  // Otherwise, check that this operation isn't one generated by this pattern.
2071  // This is because we will attempt to legalize the parent operation, and
2072  // blocks in regions created by this pattern will already be legalized later
2073  // on. If we haven't built the set yet, build it now.
2074  if (operationsToIgnore.empty()) {
2075  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2076  ++i) {
2077  auto *createOp =
2078  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2079  if (!createOp)
2080  continue;
2081  operationsToIgnore.insert(createOp->getOperation());
2082  }
2083  }
2084 
2085  // If this operation should be considered for re-legalization, try it.
2086  if (operationsToIgnore.insert(parentOp).second &&
2087  failed(legalize(parentOp, rewriter))) {
2088  LLVM_DEBUG(logFailure(impl.logger,
2089  "operation '{0}'({1}) became illegal after rewrite",
2090  parentOp->getName(), parentOp));
2091  return failure();
2092  }
2093  }
2094  return success();
2095 }
2096 
2097 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2099  RewriterState &state, RewriterState &newState) {
2100  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2101  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2102  if (!createOp)
2103  continue;
2104  Operation *op = createOp->getOperation();
2105  if (failed(legalize(op, rewriter))) {
2106  LLVM_DEBUG(logFailure(impl.logger,
2107  "failed to legalize generated operation '{0}'({1})",
2108  op->getName(), op));
2109  return failure();
2110  }
2111  }
2112  return success();
2113 }
2114 
2115 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2117  RewriterState &state, RewriterState &newState) {
2118  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2119  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2120  if (!rewrite)
2121  continue;
2122  Operation *op = rewrite->getOperation();
2123  if (failed(legalize(op, rewriter))) {
2124  LLVM_DEBUG(logFailure(
2125  impl.logger, "failed to legalize operation updated in-place '{0}'",
2126  op->getName()));
2127  return failure();
2128  }
2129  }
2130  return success();
2131 }
2132 
2133 //===----------------------------------------------------------------------===//
2134 // Cost Model
2135 
2136 void OperationLegalizer::buildLegalizationGraph(
2137  LegalizationPatterns &anyOpLegalizerPatterns,
2138  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2139  // A mapping between an operation and a set of operations that can be used to
2140  // generate it.
2142  // A mapping between an operation and any currently invalid patterns it has.
2144  // A worklist of patterns to consider for legality.
2145  SetVector<const Pattern *> patternWorklist;
2146 
2147  // Build the mapping from operations to the parent ops that may generate them.
2148  applicator.walkAllPatterns([&](const Pattern &pattern) {
2149  std::optional<OperationName> root = pattern.getRootKind();
2150 
2151  // If the pattern has no specific root, we can't analyze the relationship
2152  // between the root op and generated operations. Given that, add all such
2153  // patterns to the legalization set.
2154  if (!root) {
2155  anyOpLegalizerPatterns.push_back(&pattern);
2156  return;
2157  }
2158 
2159  // Skip operations that are always known to be legal.
2160  if (target.getOpAction(*root) == LegalizationAction::Legal)
2161  return;
2162 
2163  // Add this pattern to the invalid set for the root op and record this root
2164  // as a parent for any generated operations.
2165  invalidPatterns[*root].insert(&pattern);
2166  for (auto op : pattern.getGeneratedOps())
2167  parentOps[op].insert(*root);
2168 
2169  // Add this pattern to the worklist.
2170  patternWorklist.insert(&pattern);
2171  });
2172 
2173  // If there are any patterns that don't have a specific root kind, we can't
2174  // make direct assumptions about what operations will never be legalized.
2175  // Note: Technically we could, but it would require an analysis that may
2176  // recurse into itself. It would be better to perform this kind of filtering
2177  // at a higher level than here anyways.
2178  if (!anyOpLegalizerPatterns.empty()) {
2179  for (const Pattern *pattern : patternWorklist)
2180  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2181  return;
2182  }
2183 
2184  while (!patternWorklist.empty()) {
2185  auto *pattern = patternWorklist.pop_back_val();
2186 
2187  // Check to see if any of the generated operations are invalid.
2188  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2189  std::optional<LegalizationAction> action = target.getOpAction(op);
2190  return !legalizerPatterns.count(op) &&
2191  (!action || action == LegalizationAction::Illegal);
2192  }))
2193  continue;
2194 
2195  // Otherwise, if all of the generated operation are valid, this op is now
2196  // legal so add all of the child patterns to the worklist.
2197  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2198  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2199 
2200  // Add any invalid patterns of the parent operations to see if they have now
2201  // become legal.
2202  for (auto op : parentOps[*pattern->getRootKind()])
2203  patternWorklist.set_union(invalidPatterns[op]);
2204  }
2205 }
2206 
2207 void OperationLegalizer::computeLegalizationGraphBenefit(
2208  LegalizationPatterns &anyOpLegalizerPatterns,
2209  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2210  // The smallest pattern depth, when legalizing an operation.
2211  DenseMap<OperationName, unsigned> minOpPatternDepth;
2212 
2213  // For each operation that is transitively legal, compute a cost for it.
2214  for (auto &opIt : legalizerPatterns)
2215  if (!minOpPatternDepth.count(opIt.first))
2216  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2217  legalizerPatterns);
2218 
2219  // Apply the cost model to the patterns that can match any operation. Those
2220  // with a specific operation type are already resolved when computing the op
2221  // legalization depth.
2222  if (!anyOpLegalizerPatterns.empty())
2223  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2224  legalizerPatterns);
2225 
2226  // Apply a cost model to the pattern applicator. We order patterns first by
2227  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2228  // decreasing benefit.
2229  applicator.applyCostModel([&](const Pattern &pattern) {
2230  ArrayRef<const Pattern *> orderedPatternList;
2231  if (std::optional<OperationName> rootName = pattern.getRootKind())
2232  orderedPatternList = legalizerPatterns[*rootName];
2233  else
2234  orderedPatternList = anyOpLegalizerPatterns;
2235 
2236  // If the pattern is not found, then it was removed and cannot be matched.
2237  auto *it = llvm::find(orderedPatternList, &pattern);
2238  if (it == orderedPatternList.end())
2240 
2241  // Patterns found earlier in the list have higher benefit.
2242  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2243  });
2244 }
2245 
2246 unsigned OperationLegalizer::computeOpLegalizationDepth(
2247  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2248  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2249  // Check for existing depth.
2250  auto depthIt = minOpPatternDepth.find(op);
2251  if (depthIt != minOpPatternDepth.end())
2252  return depthIt->second;
2253 
2254  // If a mapping for this operation does not exist, then this operation
2255  // is always legal. Return 0 as the depth for a directly legal operation.
2256  auto opPatternsIt = legalizerPatterns.find(op);
2257  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2258  return 0u;
2259 
2260  // Record this initial depth in case we encounter this op again when
2261  // recursively computing the depth.
2262  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2263 
2264  // Apply the cost model to the operation patterns, and update the minimum
2265  // depth.
2266  unsigned minDepth = applyCostModelToPatterns(
2267  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2268  minOpPatternDepth[op] = minDepth;
2269  return minDepth;
2270 }
2271 
2272 unsigned OperationLegalizer::applyCostModelToPatterns(
2273  LegalizationPatterns &patterns,
2274  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2275  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2276  unsigned minDepth = std::numeric_limits<unsigned>::max();
2277 
2278  // Compute the depth for each pattern within the set.
2280  patternsByDepth.reserve(patterns.size());
2281  for (const Pattern *pattern : patterns) {
2282  unsigned depth = 1;
2283  for (auto generatedOp : pattern->getGeneratedOps()) {
2284  unsigned generatedOpDepth = computeOpLegalizationDepth(
2285  generatedOp, minOpPatternDepth, legalizerPatterns);
2286  depth = std::max(depth, generatedOpDepth + 1);
2287  }
2288  patternsByDepth.emplace_back(pattern, depth);
2289 
2290  // Update the minimum depth of the pattern list.
2291  minDepth = std::min(minDepth, depth);
2292  }
2293 
2294  // If the operation only has one legalization pattern, there is no need to
2295  // sort them.
2296  if (patternsByDepth.size() == 1)
2297  return minDepth;
2298 
2299  // Sort the patterns by those likely to be the most beneficial.
2300  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2301  [](const std::pair<const Pattern *, unsigned> &lhs,
2302  const std::pair<const Pattern *, unsigned> &rhs) {
2303  // First sort by the smaller pattern legalization
2304  // depth.
2305  if (lhs.second != rhs.second)
2306  return lhs.second < rhs.second;
2307 
2308  // Then sort by the larger pattern benefit.
2309  auto lhsBenefit = lhs.first->getBenefit();
2310  auto rhsBenefit = rhs.first->getBenefit();
2311  return lhsBenefit > rhsBenefit;
2312  });
2313 
2314  // Update the legalization pattern to use the new sorted list.
2315  patterns.clear();
2316  for (auto &patternIt : patternsByDepth)
2317  patterns.push_back(patternIt.first);
2318  return minDepth;
2319 }
2320 
2321 //===----------------------------------------------------------------------===//
2322 // OperationConverter
2323 //===----------------------------------------------------------------------===//
namespace {
/// Selects how the operation converter treats operations it cannot legalize.
enum OpConversionMode {
  /// Failed conversions are tolerated (unless the operation is explicitly
  /// illegal), so illegal operations may remain in the IR next to legal ones.
  Partial,

  /// Every operation must become legal on the target; otherwise the
  /// conversion fails.
  Full,

  /// Operations are only analyzed for legality; on success no rewrite is
  /// actually applied to them.
  Analysis,
};
} // namespace
2339 
2340 namespace mlir {
2341 // This class converts operations to a given conversion target via a set of
2342 // rewrite patterns. The conversion behaves differently depending on the
2343 // conversion mode.
2345  explicit OperationConverter(const ConversionTarget &target,
2346  const FrozenRewritePatternSet &patterns,
2347  const ConversionConfig &config,
2348  OpConversionMode mode)
2349  : config(config), opLegalizer(target, patterns, this->config),
2350  mode(mode) {}
2351 
2352  /// Converts the given operations to the conversion target.
2353  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2354 
2355 private:
2356  /// Converts an operation with the given rewriter.
2357  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2358 
2359  /// This method is called after the conversion process to legalize any
2360  /// remaining artifacts and complete the conversion.
2361  void finalize(ConversionPatternRewriter &rewriter);
2362 
2363  /// Dialect conversion configuration.
2364  ConversionConfig config;
2365 
2366  /// The legalizer to use when converting operations.
2367  OperationLegalizer opLegalizer;
2368 
2369  /// The conversion mode to use when legalizing operations.
2370  OpConversionMode mode;
2371 };
2372 } // namespace mlir
2373 
2374 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2375  Operation *op) {
2376  // Legalize the given operation.
2377  if (failed(opLegalizer.legalize(op, rewriter))) {
2378  // Handle the case of a failed conversion for each of the different modes.
2379  // Full conversions expect all operations to be converted.
2380  if (mode == OpConversionMode::Full)
2381  return op->emitError()
2382  << "failed to legalize operation '" << op->getName() << "'";
2383  // Partial conversions allow conversions to fail iff the operation was not
2384  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2385  // set, non-legalizable ops are added to that set.
2386  if (mode == OpConversionMode::Partial) {
2387  if (opLegalizer.isIllegal(op))
2388  return op->emitError()
2389  << "failed to legalize operation '" << op->getName()
2390  << "' that was explicitly marked illegal";
2391  if (config.unlegalizedOps)
2392  config.unlegalizedOps->insert(op);
2393  }
2394  } else if (mode == OpConversionMode::Analysis) {
2395  // Analysis conversions don't fail if any operations fail to legalize,
2396  // they are only interested in the operations that were successfully
2397  // legalized.
2398  if (config.legalizableOps)
2399  config.legalizableOps->insert(op);
2400  }
2401  return success();
2402 }
2403 
2404 static LogicalResult
2406  UnresolvedMaterializationRewrite *rewrite) {
2407  UnrealizedConversionCastOp op = rewrite->getOperation();
2408  assert(!op.use_empty() &&
2409  "expected that dead materializations have already been DCE'd");
2410  Operation::operand_range inputOperands = op.getOperands();
2411  Type outputType = op.getResultTypes()[0];
2412 
2413  // Try to materialize the conversion.
2414  if (const TypeConverter *converter = rewrite->getConverter()) {
2415  rewriter.setInsertionPoint(op);
2416  Value newMaterialization;
2417  switch (rewrite->getMaterializationKind()) {
2419  // Try to materialize an argument conversion.
2420  newMaterialization = converter->materializeArgumentConversion(
2421  rewriter, op->getLoc(), outputType, inputOperands);
2422  if (newMaterialization)
2423  break;
2424  // If an argument materialization failed, fallback to trying a target
2425  // materialization.
2426  [[fallthrough]];
2427  case MaterializationKind::Target:
2428  newMaterialization = converter->materializeTargetConversion(
2429  rewriter, op->getLoc(), outputType, inputOperands,
2430  rewrite->getOriginalType());
2431  break;
2432  case MaterializationKind::Source:
2433  newMaterialization = converter->materializeSourceConversion(
2434  rewriter, op->getLoc(), outputType, inputOperands);
2435  break;
2436  }
2437  if (newMaterialization) {
2438  assert(newMaterialization.getType() == outputType &&
2439  "materialization callback produced value of incorrect type");
2440  rewriter.replaceOp(op, newMaterialization);
2441  return success();
2442  }
2443  }
2444 
2446  << "failed to legalize unresolved materialization "
2447  "from ("
2448  << inputOperands.getTypes() << ") to " << outputType
2449  << " that remained live after conversion";
2450  diag.attachNote(op->getUsers().begin()->getLoc())
2451  << "see existing live user here: " << *op->getUsers().begin();
2452  return failure();
2453 }
2454 
2456  if (ops.empty())
2457  return success();
2458  const ConversionTarget &target = opLegalizer.getTarget();
2459 
2460  // Compute the set of operations and blocks to convert.
2461  SmallVector<Operation *> toConvert;
2462  for (auto *op : ops) {
2464  [&](Operation *op) {
2465  toConvert.push_back(op);
2466  // Don't check this operation's children for conversion if the
2467  // operation is recursively legal.
2468  auto legalityInfo = target.isLegal(op);
2469  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2470  return WalkResult::skip();
2471  return WalkResult::advance();
2472  });
2473  }
2474 
2475  // Convert each operation and discard rewrites on failure.
2476  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2477  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2478 
2479  for (auto *op : toConvert)
2480  if (failed(convert(rewriter, op)))
2481  return rewriterImpl.undoRewrites(), failure();
2482 
2483  // Now that all of the operations have been converted, finalize the conversion
2484  // process to ensure any lingering conversion artifacts are cleaned up and
2485  // legalized.
2486  finalize(rewriter);
2487 
2488  // After a successful conversion, apply rewrites.
2489  rewriterImpl.applyRewrites();
2490 
2491  // Gather all unresolved materializations.
2494  &materializations = rewriterImpl.unresolvedMaterializations;
2495  for (auto it : materializations) {
2496  if (rewriterImpl.eraseRewriter.wasErased(it.first))
2497  continue;
2498  allCastOps.push_back(it.first);
2499  }
2500 
2501  // Reconcile all UnrealizedConversionCastOps that were inserted by the
2502  // dialect conversion frameworks. (Not the one that were inserted by
2503  // patterns.)
2504  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2505  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2506 
2507  // Try to legalize all unresolved materializations.
2508  if (config.buildMaterializations) {
2509  IRRewriter rewriter(rewriterImpl.context, config.listener);
2510  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2511  auto it = materializations.find(castOp);
2512  assert(it != materializations.end() && "inconsistent state");
2513  if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2514  return failure();
2515  }
2516  }
2517 
2518  return success();
2519 }
2520 
2521 /// Finds a user of the given value, or of any other value that the given value
2522 /// replaced, that was not replaced in the conversion process.
2524  Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2525  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2526  SmallVector<Value> worklist = {initialValue};
2527  while (!worklist.empty()) {
2528  Value value = worklist.pop_back_val();
2529 
2530  // Walk the users of this value to see if there are any live users that
2531  // weren't replaced during conversion.
2532  auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2533  return rewriterImpl.isOpIgnored(user);
2534  });
2535  if (liveUserIt != value.user_end())
2536  return *liveUserIt;
2537  auto mapIt = inverseMapping.find(value);
2538  if (mapIt != inverseMapping.end())
2539  worklist.append(mapIt->second);
2540  }
2541  return nullptr;
2542 }
2543 
2544 /// Helper function that returns the replaced values and the type converter if
2545 /// the given rewrite object is an "operation replacement" or a "block type
2546 /// conversion" (which corresponds to a "block replacement"). Otherwise, return
2547 /// an empty ValueRange and a null type converter pointer.
2548 static std::pair<ValueRange, const TypeConverter *>
2550  if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
2551  return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
2552  if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
2553  return {blockRewrite->getOrigBlock()->getArguments(),
2554  blockRewrite->getConverter()};
2555  return {};
2556 }
2557 
2558 void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2559  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2560  DenseMap<Value, SmallVector<Value>> inverseMapping =
2561  rewriterImpl.mapping.getInverse();
2562 
2563  // Process requested value replacements.
2564  for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
2565  ValueRange replacedValues;
2566  const TypeConverter *converter;
2567  std::tie(replacedValues, converter) =
2568  getReplacedValues(rewriterImpl.rewrites[i].get());
2569  for (Value originalValue : replacedValues) {
2570  // If the type of this value changed and the value is still live, we need
2571  // to materialize a conversion.
2572  if (rewriterImpl.mapping.lookupOrNull(originalValue,
2573  originalValue.getType()))
2574  continue;
2575  Operation *liveUser =
2576  findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
2577  if (!liveUser)
2578  continue;
2579 
2580  // Legalize this value replacement.
2581  Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
2582  assert(newValue && "replacement value not found");
2583  Value castValue = rewriterImpl.buildUnresolvedMaterialization(
2584  MaterializationKind::Source, computeInsertPoint(newValue),
2585  originalValue.getLoc(),
2586  /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
2587  /*originalType=*/Type(), converter);
2588  rewriterImpl.mapping.map(originalValue, castValue);
2589  inverseMapping[castValue].push_back(originalValue);
2590  llvm::erase(inverseMapping[newValue], originalValue);
2591  }
2592  }
2593 }
2594 
2595 //===----------------------------------------------------------------------===//
2596 // Reconcile Unrealized Casts
2597 //===----------------------------------------------------------------------===//
2598 
2601  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2602  SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
2603  castOps.end());
2604  // This set is maintained only if `remainingCastOps` is provided.
2605  DenseSet<Operation *> erasedOps;
2606 
2607  // Helper function that adds all operands to the worklist that are an
2608  // unrealized_conversion_cast op result.
2609  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2610  for (Value v : castOp.getInputs())
2611  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2612  worklist.insert(inputCastOp);
2613  };
2614 
2615  // Helper function that return the unrealized_conversion_cast op that
2616  // defines all inputs of the given op (in the same order). Return "nullptr"
2617  // if there is no such op.
2618  auto getInputCast =
2619  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2620  if (castOp.getInputs().empty())
2621  return {};
2622  auto inputCastOp =
2623  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2624  if (!inputCastOp)
2625  return {};
2626  if (inputCastOp.getOutputs() != castOp.getInputs())
2627  return {};
2628  return inputCastOp;
2629  };
2630 
2631  // Process ops in the worklist bottom-to-top.
2632  while (!worklist.empty()) {
2633  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2634  if (castOp->use_empty()) {
2635  // DCE: If the op has no users, erase it. Add the operands to the
2636  // worklist to find additional DCE opportunities.
2637  enqueueOperands(castOp);
2638  if (remainingCastOps)
2639  erasedOps.insert(castOp.getOperation());
2640  castOp->erase();
2641  continue;
2642  }
2643 
2644  // Traverse the chain of input cast ops to see if an op with the same
2645  // input types can be found.
2646  UnrealizedConversionCastOp nextCast = castOp;
2647  while (nextCast) {
2648  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2649  // Found a cast where the input types match the output types of the
2650  // matched op. We can directly use those inputs and the matched op can
2651  // be removed.
2652  enqueueOperands(castOp);
2653  castOp.replaceAllUsesWith(nextCast.getInputs());
2654  if (remainingCastOps)
2655  erasedOps.insert(castOp.getOperation());
2656  castOp->erase();
2657  break;
2658  }
2659  nextCast = getInputCast(nextCast);
2660  }
2661  }
2662 
2663  if (remainingCastOps)
2664  for (UnrealizedConversionCastOp op : castOps)
2665  if (!erasedOps.contains(op.getOperation()))
2666  remainingCastOps->push_back(op);
2667 }
2668 
2669 //===----------------------------------------------------------------------===//
2670 // Type Conversion
2671 //===----------------------------------------------------------------------===//
2672 
2674  ArrayRef<Type> types) {
2675  assert(!types.empty() && "expected valid types");
2676  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2677  addInputs(types);
2678 }
2679 
2681  assert(!types.empty() &&
2682  "1->0 type remappings don't need to be added explicitly");
2683  argTypes.append(types.begin(), types.end());
2684 }
2685 
2686 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2687  unsigned newInputNo,
2688  unsigned newInputCount) {
2689  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2690  assert(newInputCount != 0 && "expected valid input count");
2691  remappedInputs[origInputNo] =
2692  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2693 }
2694 
2696  Value replacementValue) {
2697  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2698  remappedInputs[origInputNo] =
2699  InputMapping{origInputNo, /*size=*/0, replacementValue};
2700 }
2701 
2703  SmallVectorImpl<Type> &results) const {
2704  {
2705  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2706  std::defer_lock);
2708  cacheReadLock.lock();
2709  auto existingIt = cachedDirectConversions.find(t);
2710  if (existingIt != cachedDirectConversions.end()) {
2711  if (existingIt->second)
2712  results.push_back(existingIt->second);
2713  return success(existingIt->second != nullptr);
2714  }
2715  auto multiIt = cachedMultiConversions.find(t);
2716  if (multiIt != cachedMultiConversions.end()) {
2717  results.append(multiIt->second.begin(), multiIt->second.end());
2718  return success();
2719  }
2720  }
2721  // Walk the added converters in reverse order to apply the most recently
2722  // registered first.
2723  size_t currentCount = results.size();
2724 
2725  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2726  std::defer_lock);
2727 
2728  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2729  if (std::optional<LogicalResult> result = converter(t, results)) {
2731  cacheWriteLock.lock();
2732  if (!succeeded(*result)) {
2733  cachedDirectConversions.try_emplace(t, nullptr);
2734  return failure();
2735  }
2736  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2737  if (newTypes.size() == 1)
2738  cachedDirectConversions.try_emplace(t, newTypes.front());
2739  else
2740  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2741  return success();
2742  }
2743  }
2744  return failure();
2745 }
2746 
2748  // Use the multi-type result version to convert the type.
2749  SmallVector<Type, 1> results;
2750  if (failed(convertType(t, results)))
2751  return nullptr;
2752 
2753  // Check to ensure that only one type was produced.
2754  return results.size() == 1 ? results.front() : nullptr;
2755 }
2756 
2757 LogicalResult
2759  SmallVectorImpl<Type> &results) const {
2760  for (Type type : types)
2761  if (failed(convertType(type, results)))
2762  return failure();
2763  return success();
2764 }
2765 
2766 bool TypeConverter::isLegal(Type type) const {
2767  return convertType(type) == type;
2768 }
2770  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2771 }
2772 
2773 bool TypeConverter::isLegal(Region *region) const {
2774  return llvm::all_of(*region, [this](Block &block) {
2775  return isLegal(block.getArgumentTypes());
2776  });
2777 }
2778 
2779 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2780  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2781 }
2782 
2783 LogicalResult
2785  SignatureConversion &result) const {
2786  // Try to convert the given input type.
2787  SmallVector<Type, 1> convertedTypes;
2788  if (failed(convertType(type, convertedTypes)))
2789  return failure();
2790 
2791  // If this argument is being dropped, there is nothing left to do.
2792  if (convertedTypes.empty())
2793  return success();
2794 
2795  // Otherwise, add the new inputs.
2796  result.addInputs(inputNo, convertedTypes);
2797  return success();
2798 }
2799 LogicalResult
2801  SignatureConversion &result,
2802  unsigned origInputOffset) const {
2803  for (unsigned i = 0, e = types.size(); i != e; ++i)
2804  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2805  return failure();
2806  return success();
2807 }
2808 
2810  Location loc,
2811  Type resultType,
2812  ValueRange inputs) const {
2813  for (const MaterializationCallbackFn &fn :
2814  llvm::reverse(argumentMaterializations))
2815  if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
2816  return *result;
2817  return nullptr;
2818 }
2819 
2821  Location loc, Type resultType,
2822  ValueRange inputs) const {
2823  for (const MaterializationCallbackFn &fn :
2824  llvm::reverse(sourceMaterializations))
2825  if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
2826  return *result;
2827  return nullptr;
2828 }
2829 
2831  Location loc, Type resultType,
2832  ValueRange inputs,
2833  Type originalType) const {
2834  for (const TargetMaterializationCallbackFn &fn :
2835  llvm::reverse(targetMaterializations))
2836  if (std::optional<Value> result =
2837  fn(builder, resultType, inputs, loc, originalType))
2838  return *result;
2839  return nullptr;
2840 }
2841 
2842 std::optional<TypeConverter::SignatureConversion>
2844  SignatureConversion conversion(block->getNumArguments());
2845  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2846  return std::nullopt;
2847  return conversion;
2848 }
2849 
2850 //===----------------------------------------------------------------------===//
2851 // Type attribute conversion
2852 //===----------------------------------------------------------------------===//
2855  return AttributeConversionResult(attr, resultTag);
2856 }
2857 
2860  return AttributeConversionResult(nullptr, naTag);
2861 }
2862 
2865  return AttributeConversionResult(nullptr, abortTag);
2866 }
2867 
2869  return impl.getInt() == resultTag;
2870 }
2871 
2873  return impl.getInt() == naTag;
2874 }
2875 
2877  return impl.getInt() == abortTag;
2878 }
2879 
2881  assert(hasResult() && "Cannot get result from N/A or abort");
2882  return impl.getPointer();
2883 }
2884 
2885 std::optional<Attribute>
2887  for (const TypeAttributeConversionCallbackFn &fn :
2888  llvm::reverse(typeAttributeConversions)) {
2889  AttributeConversionResult res = fn(type, attr);
2890  if (res.hasResult())
2891  return res.getResult();
2892  if (res.isAbort())
2893  return std::nullopt;
2894  }
2895  return std::nullopt;
2896 }
2897 
2898 //===----------------------------------------------------------------------===//
2899 // FunctionOpInterfaceSignatureConversion
2900 //===----------------------------------------------------------------------===//
2901 
2902 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
2903  const TypeConverter &typeConverter,
2904  ConversionPatternRewriter &rewriter) {
2905  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
2906  if (!type)
2907  return failure();
2908 
2909  // Convert the original function types.
2910  TypeConverter::SignatureConversion result(type.getNumInputs());
2911  SmallVector<Type, 1> newResults;
2912  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
2913  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
2914  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
2915  typeConverter, &result)))
2916  return failure();
2917 
2918  // Update the function signature in-place.
2919  auto newType = FunctionType::get(rewriter.getContext(),
2920  result.getConvertedTypes(), newResults);
2921 
2922  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
2923 
2924  return success();
2925 }
2926 
2927 /// Create a default conversion pattern that rewrites the type signature of a
2928 /// FunctionOpInterface op. This only supports ops which use FunctionType to
2929 /// represent their type.
2930 namespace {
2931 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
2932  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
2933  MLIRContext *ctx,
2934  const TypeConverter &converter)
2935  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
2936 
2937  LogicalResult
2938  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
2939  ConversionPatternRewriter &rewriter) const override {
2940  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
2941  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
2942  }
2943 };
2944 
2945 struct AnyFunctionOpInterfaceSignatureConversion
2946  : public OpInterfaceConversionPattern<FunctionOpInterface> {
2948 
2949  LogicalResult
2950  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
2951  ConversionPatternRewriter &rewriter) const override {
2952  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
2953  }
2954 };
2955 } // namespace
2956 
2957 FailureOr<Operation *>
2959  const TypeConverter &converter,
2960  ConversionPatternRewriter &rewriter) {
2961  assert(op && "Invalid op");
2962  Location loc = op->getLoc();
2963  if (converter.isLegal(op))
2964  return rewriter.notifyMatchFailure(loc, "op already legal");
2965 
2966  OperationState newOp(loc, op->getName());
2967  newOp.addOperands(operands);
2968 
2969  SmallVector<Type> newResultTypes;
2970  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
2971  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
2972 
2973  newOp.addTypes(newResultTypes);
2974  newOp.addAttributes(op->getAttrs());
2975  return rewriter.create(newOp);
2976 }
2977 
2979  StringRef functionLikeOpName, RewritePatternSet &patterns,
2980  const TypeConverter &converter) {
2981  patterns.add<FunctionOpInterfaceSignatureConversion>(
2982  functionLikeOpName, patterns.getContext(), converter);
2983 }
2984 
2986  RewritePatternSet &patterns, const TypeConverter &converter) {
2987  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
2988  converter, patterns.getContext());
2989 }
2990 
2991 //===----------------------------------------------------------------------===//
2992 // ConversionTarget
2993 //===----------------------------------------------------------------------===//
2994 
2996  LegalizationAction action) {
2997  legalOperations[op].action = action;
2998 }
2999 
3001  LegalizationAction action) {
3002  for (StringRef dialect : dialectNames)
3003  legalDialects[dialect] = action;
3004 }
3005 
3007  -> std::optional<LegalizationAction> {
3008  std::optional<LegalizationInfo> info = getOpInfo(op);
3009  return info ? info->action : std::optional<LegalizationAction>();
3010 }
3011 
3013  -> std::optional<LegalOpDetails> {
3014  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3015  if (!info)
3016  return std::nullopt;
3017 
3018  // Returns true if this operation instance is known to be legal.
3019  auto isOpLegal = [&] {
3020  // Handle dynamic legality either with the provided legality function.
3021  if (info->action == LegalizationAction::Dynamic) {
3022  std::optional<bool> result = info->legalityFn(op);
3023  if (result)
3024  return *result;
3025  }
3026 
3027  // Otherwise, the operation is only legal if it was marked 'Legal'.
3028  return info->action == LegalizationAction::Legal;
3029  };
3030  if (!isOpLegal())
3031  return std::nullopt;
3032 
3033  // This operation is legal, compute any additional legality information.
3034  LegalOpDetails legalityDetails;
3035  if (info->isRecursivelyLegal) {
3036  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3037  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3038  legalityDetails.isRecursivelyLegal =
3039  legalityFnIt->second(op).value_or(true);
3040  } else {
3041  legalityDetails.isRecursivelyLegal = true;
3042  }
3043  }
3044  return legalityDetails;
3045 }
3046 
3048  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3049  if (!info)
3050  return false;
3051 
3052  if (info->action == LegalizationAction::Dynamic) {
3053  std::optional<bool> result = info->legalityFn(op);
3054  if (!result)
3055  return false;
3056 
3057  return !(*result);
3058  }
3059 
3060  return info->action == LegalizationAction::Illegal;
3061 }
3062 
3066  if (!oldCallback)
3067  return newCallback;
3068 
3069  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3070  Operation *op) -> std::optional<bool> {
3071  if (std::optional<bool> result = newCl(op))
3072  return *result;
3073 
3074  return oldCl(op);
3075  };
3076  return chain;
3077 }
3078 
3079 void ConversionTarget::setLegalityCallback(
3080  OperationName name, const DynamicLegalityCallbackFn &callback) {
3081  assert(callback && "expected valid legality callback");
3082  auto *infoIt = legalOperations.find(name);
3083  assert(infoIt != legalOperations.end() &&
3084  infoIt->second.action == LegalizationAction::Dynamic &&
3085  "expected operation to already be marked as dynamically legal");
3086  infoIt->second.legalityFn =
3087  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3088 }
3089 
3091  OperationName name, const DynamicLegalityCallbackFn &callback) {
3092  auto *infoIt = legalOperations.find(name);
3093  assert(infoIt != legalOperations.end() &&
3094  infoIt->second.action != LegalizationAction::Illegal &&
3095  "expected operation to already be marked as legal");
3096  infoIt->second.isRecursivelyLegal = true;
3097  if (callback)
3098  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3099  std::move(opRecursiveLegalityFns[name]), callback);
3100  else
3101  opRecursiveLegalityFns.erase(name);
3102 }
3103 
3104 void ConversionTarget::setLegalityCallback(
3105  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3106  assert(callback && "expected valid legality callback");
3107  for (StringRef dialect : dialects)
3108  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3109  std::move(dialectLegalityFns[dialect]), callback);
3110 }
3111 
3112 void ConversionTarget::setLegalityCallback(
3113  const DynamicLegalityCallbackFn &callback) {
3114  assert(callback && "expected valid legality callback");
3115  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3116 }
3117 
3118 auto ConversionTarget::getOpInfo(OperationName op) const
3119  -> std::optional<LegalizationInfo> {
3120  // Check for info for this specific operation.
3121  const auto *it = legalOperations.find(op);
3122  if (it != legalOperations.end())
3123  return it->second;
3124  // Check for info for the parent dialect.
3125  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3126  if (dialectIt != legalDialects.end()) {
3127  DynamicLegalityCallbackFn callback;
3128  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3129  if (dialectFn != dialectLegalityFns.end())
3130  callback = dialectFn->second;
3131  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3132  callback};
3133  }
3134  // Otherwise, check if we mark unknown operations as dynamic.
3135  if (unknownLegalityFn)
3136  return LegalizationInfo{LegalizationAction::Dynamic,
3137  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3138  return std::nullopt;
3139 }
3140 
3141 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3142 //===----------------------------------------------------------------------===//
3143 // PDL Configuration
3144 //===----------------------------------------------------------------------===//
3145 
3147  auto &rewriterImpl =
3148  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3149  rewriterImpl.currentTypeConverter = getTypeConverter();
3150 }
3151 
3153  auto &rewriterImpl =
3154  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3155  rewriterImpl.currentTypeConverter = nullptr;
3156 }
3157 
3158 /// Remap the given value using the rewriter and the type converter in the
3159 /// provided config.
3160 static FailureOr<SmallVector<Value>>
3162  SmallVector<Value> mappedValues;
3163  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3164  return failure();
3165  return std::move(mappedValues);
3166 }
3167 
3169  patterns.getPDLPatterns().registerRewriteFunction(
3170  "convertValue",
3171  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3172  auto results = pdllConvertValues(
3173  static_cast<ConversionPatternRewriter &>(rewriter), value);
3174  if (failed(results))
3175  return failure();
3176  return results->front();
3177  });
3178  patterns.getPDLPatterns().registerRewriteFunction(
3179  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3180  return pdllConvertValues(
3181  static_cast<ConversionPatternRewriter &>(rewriter), values);
3182  });
3183  patterns.getPDLPatterns().registerRewriteFunction(
3184  "convertType",
3185  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3186  auto &rewriterImpl =
3187  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3188  if (const TypeConverter *converter =
3189  rewriterImpl.currentTypeConverter) {
3190  if (Type newType = converter->convertType(type))
3191  return newType;
3192  return failure();
3193  }
3194  return type;
3195  });
3196  patterns.getPDLPatterns().registerRewriteFunction(
3197  "convertTypes",
3198  [](PatternRewriter &rewriter,
3199  TypeRange types) -> FailureOr<SmallVector<Type>> {
3200  auto &rewriterImpl =
3201  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3202  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3203  if (!converter)
3204  return SmallVector<Type>(types);
3205 
3206  SmallVector<Type> remappedTypes;
3207  if (failed(converter->convertTypes(types, remappedTypes)))
3208  return failure();
3209  return std::move(remappedTypes);
3210  });
3211 }
3212 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3213 
3214 //===----------------------------------------------------------------------===//
3215 // Op Conversion Entry Points
3216 //===----------------------------------------------------------------------===//
3217 
3218 //===----------------------------------------------------------------------===//
3219 // Partial Conversion
3220 
3222  ArrayRef<Operation *> ops, const ConversionTarget &target,
3223  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3224  OperationConverter opConverter(target, patterns, config,
3225  OpConversionMode::Partial);
3226  return opConverter.convertOperations(ops);
3227 }
3228 LogicalResult
3230  const FrozenRewritePatternSet &patterns,
3231  ConversionConfig config) {
3232  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3233 }
3234 
3235 //===----------------------------------------------------------------------===//
3236 // Full Conversion
3237 
3239  const ConversionTarget &target,
3240  const FrozenRewritePatternSet &patterns,
3241  ConversionConfig config) {
3242  OperationConverter opConverter(target, patterns, config,
3243  OpConversionMode::Full);
3244  return opConverter.convertOperations(ops);
3245 }
3247  const ConversionTarget &target,
3248  const FrozenRewritePatternSet &patterns,
3249  ConversionConfig config) {
3250  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3251 }
3252 
3253 //===----------------------------------------------------------------------===//
3254 // Analysis Conversion
3255 
3256 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3257 /// op is a top-level module op (which is expected to be isolated from above),
3258 /// return that op.
3260  // Check if there is a top-level operation within `ops`. If so, return that
3261  // op.
3262  for (Operation *op : ops) {
3263  if (!op->getParentOp()) {
3264 #ifndef NDEBUG
3265  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3266  "expected top-level op to be isolated from above");
3267  for (Operation *other : ops)
3268  assert(op->isAncestor(other) &&
3269  "expected ops to have a common ancestor");
3270 #endif // NDEBUG
3271  return op;
3272  }
3273  }
3274 
3275  // No top-level op. Find a common ancestor.
3276  Operation *commonAncestor =
3277  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3278  for (Operation *op : ops.drop_front()) {
3279  while (!commonAncestor->isProperAncestor(op)) {
3280  commonAncestor =
3282  assert(commonAncestor &&
3283  "expected to find a common isolated from above ancestor");
3284  }
3285  }
3286 
3287  return commonAncestor;
3288 }
3289 
3292  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3293 #ifndef NDEBUG
3294  if (config.legalizableOps)
3295  assert(config.legalizableOps->empty() && "expected empty set");
3296 #endif // NDEBUG
3297 
3298  // Clone closted common ancestor that is isolated from above.
3299  Operation *commonAncestor = findCommonAncestor(ops);
3300  IRMapping mapping;
3301  Operation *clonedAncestor = commonAncestor->clone(mapping);
3302  // Compute inverse IR mapping.
3303  DenseMap<Operation *, Operation *> inverseOperationMap;
3304  for (auto &it : mapping.getOperationMap())
3305  inverseOperationMap[it.second] = it.first;
3306 
3307  // Convert the cloned operations. The original IR will remain unchanged.
3308  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3309  ops, [&](Operation *op) { return mapping.lookup(op); });
3310  OperationConverter opConverter(target, patterns, config,
3311  OpConversionMode::Analysis);
3312  LogicalResult status = opConverter.convertOperations(opsToConvert);
3313 
3314  // Remap `legalizableOps`, so that they point to the original ops and not the
3315  // cloned ops.
3316  if (config.legalizableOps) {
3317  DenseSet<Operation *> originalLegalizableOps;
3318  for (Operation *op : *config.legalizableOps)
3319  originalLegalizableOps.insert(inverseOperationMap[op]);
3320  *config.legalizableOps = std::move(originalLegalizableOps);
3321  }
3322 
3323  // Erase the cloned IR.
3324  clonedAncestor->erase();
3325  return status;
3326 }
3327 
3328 LogicalResult
3330  const FrozenRewritePatternSet &patterns,
3331  ConversionConfig config) {
3332  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3333 }
static std::pair< ValueRange, const TypeConverter * > getReplacedValues(IRRewrite *rewrite)
Helper function that returns the replaced values and the type converter if the given rewrite object i...
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
bool empty()
Definition: Block.h:146
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:93
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator end()
Definition: Block.h:142
iterator begin()
Definition: Block.h:141
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:335
Block::iterator getPoint() const
Definition: Builders.h:348
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:345
Block * getBlock() const
Definition: Builders.h:347
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:328
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:512
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:605
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:892
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
MLIRContext * getContext() const
Definition: PatternMatch.h:823
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_iterator user_end() const
Definition: Value.h:227
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition: Argument.h:64
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target/ argument materializations t...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a single value that remaps an existing signature input...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/block that were already erased and skips duplicate op/block erasur...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
Value buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueRange inputs, Type outputType, Type originalType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.