MLIR  20.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Iterators.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::detail;
29 
30 #define DEBUG_TYPE "dialect-conversion"
31 
32 /// A utility function to log a successful result for the given reason.
33 template <typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
35  LLVM_DEBUG({
36  os.unindent();
37  os.startLine() << "} -> SUCCESS";
38  if (!fmt.empty())
39  os.getOStream() << " : "
40  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41  os.getOStream() << "\n";
42  });
43 }
44 
45 /// A utility function to log a failure result for the given reason.
46 template <typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
48  LLVM_DEBUG({
49  os.unindent();
50  os.startLine() << "} -> FAILURE : "
51  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
52  << "\n";
53  });
54 }
55 
56 /// Helper function that computes an insertion point where the given value is
57 /// defined and can be used without a dominance violation.
///
/// If `value` is an op result, the insertion point is immediately after its
/// defining op; otherwise (e.g. a block argument) it is the start of the block
/// that contains `value`. The result is an `OpBuilder::InsertPoint` into that
/// block.
///
/// NOTE(review): the function header line is absent from this listing (the
/// embedded numbering jumps from 57 to 59). The body reads a `value`
/// parameter and returns `OpBuilder::InsertPoint` — confirm the signature.
59  Block *insertBlock = value.getParentBlock();
60  Block::iterator insertPt = insertBlock->begin();
// For an op result, step past the defining op so that `value` is already
// available at the returned insertion point.
61  if (OpResult inputRes = dyn_cast<OpResult>(value))
62  insertPt = ++inputRes.getOwner()->getIterator();
63  return OpBuilder::InsertPoint(insertBlock, insertPt);
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // ConversionValueMapping
68 //===----------------------------------------------------------------------===//
69 
70 namespace {
71 /// This class wraps a IRMapping to provide recursive lookup
72 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
73 struct ConversionValueMapping {
74  /// Lookup the most recently mapped value with the desired type in the
75  /// mapping.
76  ///
77  /// Special cases:
78  /// - If the desired type is "null", simply return the most recently mapped
79  /// value.
80  /// - If there is no mapping to the desired type, also return the most
81  /// recently mapped value.
82  /// - If there is no mapping for the given value at all, return the given
83  /// value.
84  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
85 
86  /// Lookup a mapped value within the map, or return null if a mapping does not
87  /// exist. If a mapping exists, this follows the same behavior of
88  /// `lookupOrDefault`.
89  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
90 
91  /// Map a value to the one provided.
92  void map(Value oldVal, Value newVal) {
93  LLVM_DEBUG({
94  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
95  assert(it != oldVal && "inserting cyclic mapping");
96  });
97  mapping.map(oldVal, newVal);
98  }
99 
100  /// Try to map a value to the one provided. Returns false if a transitive
101  /// mapping from the new value to the old value already exists, true if the
102  /// map was updated.
103  bool tryMap(Value oldVal, Value newVal);
104 
105  /// Drop the last mapping for the given value.
106  void erase(Value value) { mapping.erase(value); }
107 
108  /// Returns the inverse raw value mapping (without recursive query support).
109  DenseMap<Value, SmallVector<Value>> getInverse() const {
111  for (auto &it : mapping.getValueMap())
112  inverse[it.second].push_back(it.first);
113  return inverse;
114  }
115 
116 private:
117  /// Current value mappings.
118  IRMapping mapping;
119 };
120 } // namespace
121 
122 Value ConversionValueMapping::lookupOrDefault(Value from,
123  Type desiredType) const {
124  // Try to find the deepest value that has the desired type. If there is no
125  // such value, simply return the deepest value.
126  Value desiredValue;
127  do {
128  if (!desiredType || from.getType() == desiredType)
129  desiredValue = from;
130 
131  Value mappedValue = mapping.lookupOrNull(from);
132  if (!mappedValue)
133  break;
134  from = mappedValue;
135  } while (true);
136 
137  // If the desired value was found use it, otherwise default to the leaf value.
138  return desiredValue ? desiredValue : from;
139 }
140 
141 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
142  Value result = lookupOrDefault(from, desiredType);
143  if (result == from || (desiredType && result.getType() != desiredType))
144  return nullptr;
145  return result;
146 }
147 
148 bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
149  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
150  if (it == oldVal)
151  return false;
152  map(oldVal, newVal);
153  return true;
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // Rewriter and Translation State
158 //===----------------------------------------------------------------------===//
159 namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. Saving a
/// snapshot and later restoring it lets a batch of rewrites be undone.
struct RewriterState {
  RewriterState(unsigned rewrites, unsigned ignored, unsigned replaced)
      : numRewrites(rewrites), numIgnoredOperations(ignored),
        numReplacedOps(replaced) {}

  /// Number of rewrites that had been recorded when the snapshot was taken.
  unsigned numRewrites;

  /// Number of operations that had been marked as ignored.
  unsigned numIgnoredOperations;

  /// Number of replaced ops that were scheduled for erasure.
  unsigned numReplacedOps;
};
177 
178 //===----------------------------------------------------------------------===//
179 // IR rewrites
180 //===----------------------------------------------------------------------===//
181 
182 /// An IR rewrite that can be committed (upon success) or rolled back (upon
183 /// failure).
184 ///
185 /// The dialect conversion keeps track of IR modifications (requested by the
186 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
187 /// are directly applied to the IR as the rewriter API is used, some are applied
188 /// partially, and some are delayed until the `IRRewrite` objects are committed.
189 class IRRewrite {
190 public:
191  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
192  /// Enum values are ordered, so that they can be used in `classof`: first all
193  /// block rewrites, then all operation rewrites.
194  enum class Kind {
195  // Block rewrites
196  CreateBlock,
197  EraseBlock,
198  InlineBlock,
199  MoveBlock,
200  BlockTypeConversion,
201  ReplaceBlockArg,
202  // Operation rewrites
203  MoveOperation,
204  ModifyOperation,
205  ReplaceOperation,
206  CreateOperation,
207  UnresolvedMaterialization
208  };
209 
210  virtual ~IRRewrite() = default;
211 
212  /// Roll back the rewrite. Operations may be erased during rollback.
213  virtual void rollback() = 0;
214 
215  /// Commit the rewrite. At this point, it is certain that the dialect
216  /// conversion will succeed. All IR modifications, except for operation/block
217  /// erasure, must be performed through the given rewriter.
218  ///
219  /// Instead of erasing operations/blocks, they should merely be unlinked
220  /// commit phase and finally be erased during the cleanup phase. This is
221  /// because internal dialect conversion state (such as `mapping`) may still
222  /// be using them.
223  ///
224  /// Any IR modification that was already performed before the commit phase
225  /// (e.g., insertion of an op) must be communicated to the listener that may
226  /// be attached to the given rewriter.
227  virtual void commit(RewriterBase &rewriter) {}
228 
229  /// Cleanup operations/blocks. Cleanup is called after commit.
230  virtual void cleanup(RewriterBase &rewriter) {}
231 
232  Kind getKind() const { return kind; }
233 
234  static bool classof(const IRRewrite *rewrite) { return true; }
235 
236 protected:
237  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
238  : kind(kind), rewriterImpl(rewriterImpl) {}
239 
240  const ConversionConfig &getConfig() const;
241 
242  const Kind kind;
243  ConversionPatternRewriterImpl &rewriterImpl;
244 };
245 
246 /// A block rewrite.
247 class BlockRewrite : public IRRewrite {
248 public:
249  /// Return the block that this rewrite operates on.
250  Block *getBlock() const { return block; }
251 
252  static bool classof(const IRRewrite *rewrite) {
253  return rewrite->getKind() >= Kind::CreateBlock &&
254  rewrite->getKind() <= Kind::ReplaceBlockArg;
255  }
256 
257 protected:
258  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
259  Block *block)
260  : IRRewrite(kind, rewriterImpl), block(block) {}
261 
262  // The block that this rewrite operates on.
263  Block *block;
264 };
265 
266 /// Creation of a block. Block creations are immediately reflected in the IR.
267 /// There is no extra work to commit the rewrite. During rollback, the newly
268 /// created block is erased.
269 class CreateBlockRewrite : public BlockRewrite {
270 public:
271  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
272  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
273 
274  static bool classof(const IRRewrite *rewrite) {
275  return rewrite->getKind() == Kind::CreateBlock;
276  }
277 
278  void commit(RewriterBase &rewriter) override {
279  // The block was already created and inserted. Just inform the listener.
280  if (auto *listener = rewriter.getListener())
281  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
282  }
283 
284  void rollback() override {
285  // Unlink all of the operations within this block, they will be deleted
286  // separately.
287  auto &blockOps = block->getOperations();
288  while (!blockOps.empty())
289  blockOps.remove(blockOps.begin());
290  block->dropAllUses();
291  if (block->getParent())
292  block->erase();
293  else
294  delete block;
295  }
296 };
297 
298 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
299 /// blocks are immediately unlinked, but only erased during cleanup. This makes
300 /// it easier to rollback a block erasure: the block is simply inserted into its
301 /// original location.
302 class EraseBlockRewrite : public BlockRewrite {
303 public:
304  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
305  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
306  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
307 
308  static bool classof(const IRRewrite *rewrite) {
309  return rewrite->getKind() == Kind::EraseBlock;
310  }
311 
312  ~EraseBlockRewrite() override {
313  assert(!block &&
314  "rewrite was neither rolled back nor committed/cleaned up");
315  }
316 
317  void rollback() override {
318  // The block (owned by this rewrite) was not actually erased yet. It was
319  // just unlinked. Put it back into its original position.
320  assert(block && "expected block");
321  auto &blockList = region->getBlocks();
322  Region::iterator before = insertBeforeBlock
323  ? Region::iterator(insertBeforeBlock)
324  : blockList.end();
325  blockList.insert(before, block);
326  block = nullptr;
327  }
328 
329  void commit(RewriterBase &rewriter) override {
330  // Erase the block.
331  assert(block && "expected block");
332  assert(block->empty() && "expected empty block");
333 
334  // Notify the listener that the block is about to be erased.
335  if (auto *listener =
336  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
337  listener->notifyBlockErased(block);
338  }
339 
340  void cleanup(RewriterBase &rewriter) override {
341  // Erase the block.
342  block->dropAllDefinedValueUses();
343  delete block;
344  block = nullptr;
345  }
346 
347 private:
348  // The region in which this block was previously contained.
349  Region *region;
350 
351  // The original successor of this block before it was unlinked. "nullptr" if
352  // this block was the only block in the region.
353  Block *insertBeforeBlock;
354 };
355 
356 /// Inlining of a block. This rewrite is immediately reflected in the IR.
357 /// Note: This rewrite represents only the inlining of the operations. The
358 /// erasure of the inlined block is a separate rewrite.
359 class InlineBlockRewrite : public BlockRewrite {
360 public:
361  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
362  Block *sourceBlock, Block::iterator before)
363  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
364  sourceBlock(sourceBlock),
365  firstInlinedInst(sourceBlock->empty() ? nullptr
366  : &sourceBlock->front()),
367  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
368  // If a listener is attached to the dialect conversion, ops must be moved
369  // one-by-one. When they are moved in bulk, notifications cannot be sent
370  // because the ops that used to be in the source block at the time of the
371  // inlining (before the "commit" phase) are unknown at the time when
372  // notifications are sent (which is during the "commit" phase).
373  assert(!getConfig().listener &&
374  "InlineBlockRewrite not supported if listener is attached");
375  }
376 
377  static bool classof(const IRRewrite *rewrite) {
378  return rewrite->getKind() == Kind::InlineBlock;
379  }
380 
381  void rollback() override {
382  // Put the operations from the destination block (owned by the rewrite)
383  // back into the source block.
384  if (firstInlinedInst) {
385  assert(lastInlinedInst && "expected operation");
386  sourceBlock->getOperations().splice(sourceBlock->begin(),
387  block->getOperations(),
388  Block::iterator(firstInlinedInst),
389  ++Block::iterator(lastInlinedInst));
390  }
391  }
392 
393 private:
394  // The block that originally contained the operations.
395  Block *sourceBlock;
396 
397  // The first inlined operation.
398  Operation *firstInlinedInst;
399 
400  // The last inlined operation.
401  Operation *lastInlinedInst;
402 };
403 
404 /// Moving of a block. This rewrite is immediately reflected in the IR.
405 class MoveBlockRewrite : public BlockRewrite {
406 public:
407  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
408  Region *region, Block *insertBeforeBlock)
409  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
410  insertBeforeBlock(insertBeforeBlock) {}
411 
412  static bool classof(const IRRewrite *rewrite) {
413  return rewrite->getKind() == Kind::MoveBlock;
414  }
415 
416  void commit(RewriterBase &rewriter) override {
417  // The block was already moved. Just inform the listener.
418  if (auto *listener = rewriter.getListener()) {
419  // Note: `previousIt` cannot be passed because this is a delayed
420  // notification and iterators into past IR state cannot be represented.
421  listener->notifyBlockInserted(block, /*previous=*/region,
422  /*previousIt=*/{});
423  }
424  }
425 
426  void rollback() override {
427  // Move the block back to its original position.
428  Region::iterator before =
429  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
430  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
431  }
432 
433 private:
434  // The region in which this block was previously contained.
435  Region *region;
436 
437  // The original successor of this block before it was moved. "nullptr" if
438  // this block was the only block in the region.
439  Block *insertBeforeBlock;
440 };
441 
442 /// Block type conversion. This rewrite is partially reflected in the IR.
443 class BlockTypeConversionRewrite : public BlockRewrite {
444 public:
445  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
446  Block *block, Block *origBlock,
447  const TypeConverter *converter)
448  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
449  origBlock(origBlock), converter(converter) {}
450 
451  static bool classof(const IRRewrite *rewrite) {
452  return rewrite->getKind() == Kind::BlockTypeConversion;
453  }
454 
455  Block *getOrigBlock() const { return origBlock; }
456 
457  const TypeConverter *getConverter() const { return converter; }
458 
459  void commit(RewriterBase &rewriter) override;
460 
461  void rollback() override;
462 
463 private:
464  /// The original block that was requested to have its signature converted.
465  Block *origBlock;
466 
467  /// The type converter used to convert the arguments.
468  const TypeConverter *converter;
469 };
470 
471 /// Replacing a block argument. This rewrite is not immediately reflected in the
472 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
473 /// until the rewrite is committed.
474 class ReplaceBlockArgRewrite : public BlockRewrite {
475 public:
476  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
477  Block *block, BlockArgument arg)
478  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
479 
480  static bool classof(const IRRewrite *rewrite) {
481  return rewrite->getKind() == Kind::ReplaceBlockArg;
482  }
483 
484  void commit(RewriterBase &rewriter) override;
485 
486  void rollback() override;
487 
488 private:
489  BlockArgument arg;
490 };
491 
492 /// An operation rewrite.
493 class OperationRewrite : public IRRewrite {
494 public:
495  /// Return the operation that this rewrite operates on.
496  Operation *getOperation() const { return op; }
497 
498  static bool classof(const IRRewrite *rewrite) {
499  return rewrite->getKind() >= Kind::MoveOperation &&
500  rewrite->getKind() <= Kind::UnresolvedMaterialization;
501  }
502 
503 protected:
504  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
505  Operation *op)
506  : IRRewrite(kind, rewriterImpl), op(op) {}
507 
508  // The operation that this rewrite operates on.
509  Operation *op;
510 };
511 
512 /// Moving of an operation. This rewrite is immediately reflected in the IR.
513 class MoveOperationRewrite : public OperationRewrite {
514 public:
515  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
516  Operation *op, Block *block, Operation *insertBeforeOp)
517  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
518  insertBeforeOp(insertBeforeOp) {}
519 
520  static bool classof(const IRRewrite *rewrite) {
521  return rewrite->getKind() == Kind::MoveOperation;
522  }
523 
524  void commit(RewriterBase &rewriter) override {
525  // The operation was already moved. Just inform the listener.
526  if (auto *listener = rewriter.getListener()) {
527  // Note: `previousIt` cannot be passed because this is a delayed
528  // notification and iterators into past IR state cannot be represented.
529  listener->notifyOperationInserted(
530  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
531  /*insertPt=*/{}));
532  }
533  }
534 
535  void rollback() override {
536  // Move the operation back to its original position.
537  Block::iterator before =
538  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
539  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
540  }
541 
542 private:
543  // The block in which this operation was previously contained.
544  Block *block;
545 
546  // The original successor of this operation before it was moved. "nullptr"
547  // if this operation was the only operation in the region.
548  Operation *insertBeforeOp;
549 };
550 
551 /// In-place modification of an op. This rewrite is immediately reflected in
552 /// the IR. The previous state of the operation is stored in this object.
553 class ModifyOperationRewrite : public OperationRewrite {
554 public:
555  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
556  Operation *op)
557  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
558  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
559  operands(op->operand_begin(), op->operand_end()),
560  successors(op->successor_begin(), op->successor_end()) {
561  if (OpaqueProperties prop = op->getPropertiesStorage()) {
562  // Make a copy of the properties.
563  propertiesStorage = operator new(op->getPropertiesStorageSize());
564  OpaqueProperties propCopy(propertiesStorage);
565  name.initOpProperties(propCopy, /*init=*/prop);
566  }
567  }
568 
569  static bool classof(const IRRewrite *rewrite) {
570  return rewrite->getKind() == Kind::ModifyOperation;
571  }
572 
573  ~ModifyOperationRewrite() override {
574  assert(!propertiesStorage &&
575  "rewrite was neither committed nor rolled back");
576  }
577 
578  void commit(RewriterBase &rewriter) override {
579  // Notify the listener that the operation was modified in-place.
580  if (auto *listener =
581  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
582  listener->notifyOperationModified(op);
583 
584  if (propertiesStorage) {
585  OpaqueProperties propCopy(propertiesStorage);
586  // Note: The operation may have been erased in the mean time, so
587  // OperationName must be stored in this object.
588  name.destroyOpProperties(propCopy);
589  operator delete(propertiesStorage);
590  propertiesStorage = nullptr;
591  }
592  }
593 
594  void rollback() override {
595  op->setLoc(loc);
596  op->setAttrs(attrs);
597  op->setOperands(operands);
598  for (const auto &it : llvm::enumerate(successors))
599  op->setSuccessor(it.value(), it.index());
600  if (propertiesStorage) {
601  OpaqueProperties propCopy(propertiesStorage);
602  op->copyProperties(propCopy);
603  name.destroyOpProperties(propCopy);
604  operator delete(propertiesStorage);
605  propertiesStorage = nullptr;
606  }
607  }
608 
609 private:
610  OperationName name;
611  LocationAttr loc;
612  DictionaryAttr attrs;
613  SmallVector<Value, 8> operands;
614  SmallVector<Block *, 2> successors;
615  void *propertiesStorage = nullptr;
616 };
617 
618 /// Replacing an operation. Erasing an operation is treated as a special case
619 /// with "null" replacements. This rewrite is not immediately reflected in the
620 /// IR. An internal IR mapping is updated, but values are not replaced and the
621 /// original op is not erased until the rewrite is committed.
622 class ReplaceOperationRewrite : public OperationRewrite {
623 public:
624  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
625  Operation *op, const TypeConverter *converter)
626  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
627  converter(converter) {}
628 
629  static bool classof(const IRRewrite *rewrite) {
630  return rewrite->getKind() == Kind::ReplaceOperation;
631  }
632 
633  void commit(RewriterBase &rewriter) override;
634 
635  void rollback() override;
636 
637  void cleanup(RewriterBase &rewriter) override;
638 
639  const TypeConverter *getConverter() const { return converter; }
640 
641 private:
642  /// An optional type converter that can be used to materialize conversions
643  /// between the new and old values if necessary.
644  const TypeConverter *converter;
645 };
646 
647 class CreateOperationRewrite : public OperationRewrite {
648 public:
649  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
650  Operation *op)
651  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
652 
653  static bool classof(const IRRewrite *rewrite) {
654  return rewrite->getKind() == Kind::CreateOperation;
655  }
656 
657  void commit(RewriterBase &rewriter) override {
658  // The operation was already created and inserted. Just inform the listener.
659  if (auto *listener = rewriter.getListener())
660  listener->notifyOperationInserted(op, /*previous=*/{});
661  }
662 
663  void rollback() override;
664 };
665 
/// The kinds of materialization. The enumerator values are spelled out; they
/// match the implicit ones (0, 1, 2).
enum MaterializationKind {
  /// Materializes a conversion for an illegal block argument type, back to the
  /// original type.
  Argument = 0,

  /// Materializes a conversion from an illegal type to a legal one.
  Target = 1,

  /// Materializes a conversion from a legal type back to an illegal one.
  Source = 2
};
680 
681 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
682 /// op. Unresolved materializations are erased at the end of the dialect
683 /// conversion.
684 class UnresolvedMaterializationRewrite : public OperationRewrite {
685 public:
686  UnresolvedMaterializationRewrite(
687  ConversionPatternRewriterImpl &rewriterImpl,
688  UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
689  MaterializationKind kind = MaterializationKind::Target);
690 
691  static bool classof(const IRRewrite *rewrite) {
692  return rewrite->getKind() == Kind::UnresolvedMaterialization;
693  }
694 
695  void rollback() override;
696 
697  UnrealizedConversionCastOp getOperation() const {
698  return cast<UnrealizedConversionCastOp>(op);
699  }
700 
701  /// Return the type converter of this materialization (which may be null).
702  const TypeConverter *getConverter() const {
703  return converterAndKind.getPointer();
704  }
705 
706  /// Return the kind of this materialization.
707  MaterializationKind getMaterializationKind() const {
708  return converterAndKind.getInt();
709  }
710 
711 private:
712  /// The corresponding type converter to use when resolving this
713  /// materialization, and the kind of this materialization.
714  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
715  converterAndKind;
716 };
717 } // namespace
718 
719 /// Return "true" if there is an operation rewrite that matches the specified
720 /// rewrite type and operation among the given rewrites.
721 template <typename RewriteTy, typename R>
722 static bool hasRewrite(R &&rewrites, Operation *op) {
723  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
724  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
725  return rewriteTy && rewriteTy->getOperation() == op;
726  });
727 }
728 
729 //===----------------------------------------------------------------------===//
730 // ConversionPatternRewriterImpl
731 //===----------------------------------------------------------------------===//
732 namespace mlir {
733 namespace detail {
736  const ConversionConfig &config)
737  : context(ctx), eraseRewriter(ctx), config(config) {}
738 
739  //===--------------------------------------------------------------------===//
740  // State Management
741  //===--------------------------------------------------------------------===//
742 
743  /// Return the current state of the rewriter.
744  RewriterState getCurrentState();
745 
746  /// Apply all requested operation rewrites. This method is invoked when the
747  /// conversion process succeeds.
748  void applyRewrites();
749 
750  /// Reset the state of the rewriter to a previously saved point.
751  void resetState(RewriterState state);
752 
753  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
754  /// failure.
755  template <typename RewriteTy, typename... Args>
756  void appendRewrite(Args &&...args) {
757  rewrites.push_back(
758  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
759  }
760 
761  /// Undo the rewrites (motions, splits) one by one in reverse order until
762  /// "numRewritesToKeep" rewrites remains.
763  void undoRewrites(unsigned numRewritesToKeep = 0);
764 
765  /// Remap the given values to those with potentially different types. Returns
766  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
767  /// is the tag used when describing a value within a diagnostic, e.g.
768  /// "operand".
769  LogicalResult remapValues(StringRef valueDiagTag,
770  std::optional<Location> inputLoc,
771  PatternRewriter &rewriter, ValueRange values,
772  SmallVectorImpl<Value> &remapped);
773 
774  /// Return "true" if the given operation is ignored, and does not need to be
775  /// converted.
776  bool isOpIgnored(Operation *op) const;
777 
778  /// Return "true" if the given operation was replaced or erased.
779  bool wasOpReplaced(Operation *op) const;
780 
781  //===--------------------------------------------------------------------===//
782  // Type Conversion
783  //===--------------------------------------------------------------------===//
784 
785  /// Convert the types of block arguments within the given region.
786  FailureOr<Block *>
787  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
788  const TypeConverter &converter,
789  TypeConverter::SignatureConversion *entryConversion);
790 
791  /// Apply the given signature conversion on the given block. The new block
792  /// containing the updated signature is returned. If no conversions were
793  /// necessary, e.g. if the block has no arguments, `block` is returned.
794  /// `converter` is used to generate any necessary cast operations that
795  /// translate between the origin argument types and those specified in the
796  /// signature conversion.
797  Block *applySignatureConversion(
798  ConversionPatternRewriter &rewriter, Block *block,
799  const TypeConverter *converter,
800  TypeConverter::SignatureConversion &signatureConversion);
801 
802  //===--------------------------------------------------------------------===//
803  // Materializations
804  //===--------------------------------------------------------------------===//
805 
806  /// Build an unresolved materialization operation given an output type and set
807  /// of input operands.
808  Value buildUnresolvedMaterialization(MaterializationKind kind,
810  ValueRange inputs, Type outputType,
811  const TypeConverter *converter);
812 
813  //===--------------------------------------------------------------------===//
814  // Rewriter Notification Hooks
815  //===--------------------------------------------------------------------===//
816 
817  //// Notifies that an op was inserted.
818  void notifyOperationInserted(Operation *op,
819  OpBuilder::InsertPoint previous) override;
820 
821  /// Notifies that an op is about to be replaced with the given values.
822  void notifyOpReplaced(Operation *op, ValueRange newValues);
823 
824  /// Notifies that a block is about to be erased.
825  void notifyBlockIsBeingErased(Block *block);
826 
827  /// Notifies that a block was inserted.
828  void notifyBlockInserted(Block *block, Region *previous,
829  Region::iterator previousIt) override;
830 
831  /// Notifies that a block is being inlined into another block.
832  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
833  Block::iterator before);
834 
835  /// Notifies that a pattern match failed for the given reason.
836  void
837  notifyMatchFailure(Location loc,
838  function_ref<void(Diagnostic &)> reasonCallback) override;
839 
840  //===--------------------------------------------------------------------===//
841  // IR Erasure
842  //===--------------------------------------------------------------------===//
843 
844  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
845  /// operation or block is erased multiple times. This rewriter assumes that
846  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
848  public:
850  : RewriterBase(context, /*listener=*/this) {}
851 
852  /// Erase the given op (unless it was already erased).
853  void eraseOp(Operation *op) override {
854  if (wasErased(op))
855  return;
856  op->dropAllUses();
858  }
859 
860  /// Erase the given block (unless it was already erased).
861  void eraseBlock(Block *block) override {
862  if (wasErased(block))
863  return;
864  assert(block->empty() && "expected empty block");
865  block->dropAllDefinedValueUses();
867  }
868 
869  bool wasErased(void *ptr) const { return erased.contains(ptr); }
870 
871  void notifyOperationErased(Operation *op) override { erased.insert(op); }
872 
873  void notifyBlockErased(Block *block) override { erased.insert(block); }
874 
875  private:
876  /// Pointers to all erased operations and blocks.
877  DenseSet<void *> erased;
878  };
879 
880  //===--------------------------------------------------------------------===//
881  // State
882  //===--------------------------------------------------------------------===//
883 
884  /// MLIR context.
886 
887  /// A rewriter that keeps track of ops/block that were already erased and
888  /// skips duplicate op/block erasures. This rewriter is used during the
889  /// "cleanup" phase.
891 
892  // Mapping between replaced values that differ in type. This happens when
893  // replacing a value with one of a different type.
894  ConversionValueMapping mapping;
895 
896  /// Ordered list of block operations (creations, splits, motions).
898 
899  /// A set of operations that should no longer be considered for legalization.
900  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
901  /// tracked separately.
903 
904  /// A set of operations that were replaced/erased. Such ops are not erased
905  /// immediately but only when the dialect conversion succeeds. In the mean
906  /// time, they should no longer be considered for legalization and any attempt
907  /// to modify/access them is invalid rewriter API usage.
909 
910  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
911  /// to the corresponding rewrite objects.
914 
915  /// The current type converter, or nullptr if no type converter is currently
916  /// active.
917  const TypeConverter *currentTypeConverter = nullptr;
918 
919  /// A mapping of regions to type converters that should be used when
920  /// converting the arguments of blocks within that region.
922 
923  /// Dialect conversion configuration.
925 
926 #ifndef NDEBUG
927  /// A set of operations that have pending updates. This tracking isn't
928  /// strictly necessary, and is thus only active during debug builds for extra
929  /// verification.
931 
932  /// A logger used to emit diagnostics during the conversion process.
933  llvm::ScopedPrinter logger{llvm::dbgs()};
934 #endif
935 };
936 } // namespace detail
937 } // namespace mlir
938 
939 const ConversionConfig &IRRewrite::getConfig() const {
940  return rewriterImpl.config;
941 }
942 
943 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
944  // Inform the listener about all IR modifications that have already taken
945  // place: References to the original block have been replaced with the new
946  // block.
947  if (auto *listener =
948  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
949  for (Operation *op : block->getUsers())
950  listener->notifyOperationModified(op);
951 }
952 
953 void BlockTypeConversionRewrite::rollback() {
954  block->replaceAllUsesWith(origBlock);
955 }
956 
957 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
958  Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
959  if (!repl)
960  return;
961 
962  if (isa<BlockArgument>(repl)) {
963  rewriter.replaceAllUsesWith(arg, repl);
964  return;
965  }
966 
967  // If the replacement value is an operation, we check to make sure that we
968  // don't replace uses that are within the parent operation of the
969  // replacement value.
970  Operation *replOp = cast<OpResult>(repl).getOwner();
971  Block *replBlock = replOp->getBlock();
972  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
973  Operation *user = operand.getOwner();
974  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
975  });
976 }
977 
978 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
979 
980 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
981  auto *listener =
982  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
983 
984  // Compute replacement values.
985  SmallVector<Value> replacements =
986  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
987  return rewriterImpl.mapping.lookupOrNull(result, result.getType());
988  });
989 
990  // Notify the listener that the operation is about to be replaced.
991  if (listener)
992  listener->notifyOperationReplaced(op, replacements);
993 
994  // Replace all uses with the new values.
995  for (auto [result, newValue] :
996  llvm::zip_equal(op->getResults(), replacements))
997  if (newValue)
998  rewriter.replaceAllUsesWith(result, newValue);
999 
1000  // The original op will be erased, so remove it from the set of unlegalized
1001  // ops.
1002  if (getConfig().unlegalizedOps)
1003  getConfig().unlegalizedOps->erase(op);
1004 
1005  // Notify the listener that the operation (and its nested operations) was
1006  // erased.
1007  if (listener) {
1009  [&](Operation *op) { listener->notifyOperationErased(op); });
1010  }
1011 
1012  // Do not erase the operation yet. It may still be referenced in `mapping`.
1013  // Just unlink it for now and erase it during cleanup.
1014  op->getBlock()->getOperations().remove(op);
1015 }
1016 
1017 void ReplaceOperationRewrite::rollback() {
1018  for (auto result : op->getResults())
1019  rewriterImpl.mapping.erase(result);
1020 }
1021 
1022 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1023  rewriter.eraseOp(op);
1024 }
1025 
1026 void CreateOperationRewrite::rollback() {
1027  for (Region &region : op->getRegions()) {
1028  while (!region.getBlocks().empty())
1029  region.getBlocks().remove(region.getBlocks().begin());
1030  }
1031  op->dropAllUses();
1032  op->erase();
1033 }
1034 
1035 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1036  ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1037  const TypeConverter *converter, MaterializationKind kind)
1038  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1039  converterAndKind(converter, kind) {
1040  rewriterImpl.unresolvedMaterializations[op] = this;
1041 }
1042 
1043 void UnresolvedMaterializationRewrite::rollback() {
1044  if (getMaterializationKind() == MaterializationKind::Target) {
1045  for (Value input : op->getOperands())
1046  rewriterImpl.mapping.erase(input);
1047  }
1048  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1049  op->erase();
1050 }
1051 
1053  // Commit all rewrites.
1054  IRRewriter rewriter(context, config.listener);
1055  for (auto &rewrite : rewrites)
1056  rewrite->commit(rewriter);
1057 
1058  // Clean up all rewrites.
1059  for (auto &rewrite : rewrites)
1060  rewrite->cleanup(eraseRewriter);
1061 }
1062 
1063 //===----------------------------------------------------------------------===//
1064 // State Management
1065 
1067  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1068 }
1069 
1071  // Undo any rewrites.
1072  undoRewrites(state.numRewrites);
1073 
1074  // Pop all of the recorded ignored operations that are no longer valid.
1075  while (ignoredOps.size() != state.numIgnoredOperations)
1076  ignoredOps.pop_back();
1077 
1078  while (replacedOps.size() != state.numReplacedOps)
1079  replacedOps.pop_back();
1080 }
1081 
1082 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1083  for (auto &rewrite :
1084  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1085  rewrite->rollback();
1086  rewrites.resize(numRewritesToKeep);
1087 }
1088 
1090  StringRef valueDiagTag, std::optional<Location> inputLoc,
1091  PatternRewriter &rewriter, ValueRange values,
1092  SmallVectorImpl<Value> &remapped) {
1093  remapped.reserve(llvm::size(values));
1094 
1095  for (const auto &it : llvm::enumerate(values)) {
1096  Value operand = it.value();
1097  Type origType = operand.getType();
1098  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1099 
1100  if (!currentTypeConverter) {
1101  // The current pattern does not have a type converter. I.e., it does not
1102  // distinguish between legal and illegal types. For each operand, simply
1103  // pass through the most recently mapped value.
1104  remapped.push_back(mapping.lookupOrDefault(operand));
1105  continue;
1106  }
1107 
1108  // If there is no legal conversion, fail to match this pattern.
1109  SmallVector<Type, 1> legalTypes;
1110  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1111  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1112  diag << "unable to convert type for " << valueDiagTag << " #"
1113  << it.index() << ", type was " << origType;
1114  });
1115  return failure();
1116  }
1117 
1118  if (legalTypes.size() != 1) {
1119  // TODO: Parts of the dialect conversion infrastructure do not support
1120  // 1->N type conversions yet. Therefore, if a type is converted to 0 or
1121  // multiple types, the only thing that we can do for now is passing
1122  // through the most recently mapped value. Fixing this requires
1123  // improvements to the `ConversionValueMapping` (to be able to store 1:N
1124  // mappings) and to the `ConversionPattern` adaptor handling (to be able
1125  // to pass multiple remapped values for a single operand to the adaptor).
1126  remapped.push_back(mapping.lookupOrDefault(operand));
1127  continue;
1128  }
1129 
1130  // Handle 1->1 type conversions.
1131  Type desiredType = legalTypes.front();
1132  // Try to find a mapped value with the desired type. (Or the operand itself
1133  // if the value is not mapped at all.)
1134  Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1135  if (newOperand.getType() != desiredType) {
1136  // If the looked up value's type does not have the desired type, it means
1137  // that the value was replaced with a value of different type and no
1138  // source materialization was created yet.
1140  MaterializationKind::Target, computeInsertPoint(newOperand),
1141  operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
1143  mapping.map(newOperand, castValue);
1144  newOperand = castValue;
1145  }
1146  remapped.push_back(newOperand);
1147  }
1148  return success();
1149 }
1150 
1152  // Check to see if this operation is ignored or was replaced.
1153  return replacedOps.count(op) || ignoredOps.count(op);
1154 }
1155 
1157  // Check to see if this operation was replaced.
1158  return replacedOps.count(op);
1159 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // Type Conversion
1163 
1165  ConversionPatternRewriter &rewriter, Region *region,
1166  const TypeConverter &converter,
1167  TypeConverter::SignatureConversion *entryConversion) {
1168  regionToConverter[region] = &converter;
1169  if (region->empty())
1170  return nullptr;
1171 
1172  // Convert the arguments of each non-entry block within the region.
1173  for (Block &block :
1174  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1175  // Compute the signature for the block with the provided converter.
1176  std::optional<TypeConverter::SignatureConversion> conversion =
1177  converter.convertBlockSignature(&block);
1178  if (!conversion)
1179  return failure();
1180  // Convert the block with the computed signature.
1181  applySignatureConversion(rewriter, &block, &converter, *conversion);
1182  }
1183 
1184  // Convert the entry block. If an entry signature conversion was provided,
1185  // use that one. Otherwise, compute the signature with the type converter.
1186  if (entryConversion)
1187  return applySignatureConversion(rewriter, &region->front(), &converter,
1188  *entryConversion);
1189  std::optional<TypeConverter::SignatureConversion> conversion =
1190  converter.convertBlockSignature(&region->front());
1191  if (!conversion)
1192  return failure();
1193  return applySignatureConversion(rewriter, &region->front(), &converter,
1194  *conversion);
1195 }
1196 
1198  ConversionPatternRewriter &rewriter, Block *block,
1199  const TypeConverter *converter,
1200  TypeConverter::SignatureConversion &signatureConversion) {
1201  OpBuilder::InsertionGuard g(rewriter);
1202 
1203  // If no arguments are being changed or added, there is nothing to do.
1204  unsigned origArgCount = block->getNumArguments();
1205  auto convertedTypes = signatureConversion.getConvertedTypes();
1206  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1207  return block;
1208 
1209  // Compute the locations of all block arguments in the new block.
1210  SmallVector<Location> newLocs(convertedTypes.size(),
1211  rewriter.getUnknownLoc());
1212  for (unsigned i = 0; i < origArgCount; ++i) {
1213  auto inputMap = signatureConversion.getInputMapping(i);
1214  if (!inputMap || inputMap->replacementValue)
1215  continue;
1216  Location origLoc = block->getArgument(i).getLoc();
1217  for (unsigned j = 0; j < inputMap->size; ++j)
1218  newLocs[inputMap->inputNo + j] = origLoc;
1219  }
1220 
1221  // Insert a new block with the converted block argument types and move all ops
1222  // from the old block to the new block.
1223  Block *newBlock =
1224  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1225  convertedTypes, newLocs);
1226 
1227  // If a listener is attached to the dialect conversion, ops cannot be moved
1228  // to the destination block in bulk ("fast path"). This is because at the time
1229  // the notifications are sent, it is unknown which ops were moved. Instead,
1230  // ops should be moved one-by-one ("slow path"), so that a separate
1231  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1232  // a bit more efficient, so we try to do that when possible.
1233  bool fastPath = !config.listener;
1234  if (fastPath) {
1235  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1236  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1237  } else {
1238  while (!block->empty())
1239  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1240  }
1241 
1242  // Replace all uses of the old block with the new block.
1243  block->replaceAllUsesWith(newBlock);
1244 
1245  for (unsigned i = 0; i != origArgCount; ++i) {
1246  BlockArgument origArg = block->getArgument(i);
1247  Type origArgType = origArg.getType();
1248 
1249  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1250  signatureConversion.getInputMapping(i);
1251  if (!inputMap) {
1252  // This block argument was dropped and no replacement value was provided.
1253  // Materialize a replacement value "out of thin air".
1255  MaterializationKind::Source,
1256  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1257  /*inputs=*/ValueRange(),
1258  /*outputType=*/origArgType, converter);
1259  mapping.map(origArg, repl);
1260  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1261  continue;
1262  }
1263 
1264  if (Value repl = inputMap->replacementValue) {
1265  // This block argument was dropped and a replacement value was provided.
1266  assert(inputMap->size == 0 &&
1267  "invalid to provide a replacement value when the argument isn't "
1268  "dropped");
1269  mapping.map(origArg, repl);
1270  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1271  continue;
1272  }
1273 
1274  // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
1275  // dialect conversion. Therefore, we need an argument materialization to
1276  // turn the replacement block arguments into a single SSA value that can be
1277  // used as a replacement.
1278  auto replArgs =
1279  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1282  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1283  /*inputs=*/replArgs, origArgType, converter);
1284  mapping.map(origArg, argMat);
1285 
1286  Type legalOutputType;
1287  if (converter) {
1288  legalOutputType = converter->convertType(origArgType);
1289  } else if (replArgs.size() == 1) {
1290  // When there is no type converter, assume that the new block argument
1291  // types are legal. This is reasonable to assume because they were
1292  // specified by the user.
1293  // FIXME: This won't work for 1->N conversions because multiple output
1294  // types are not supported in parts of the dialect conversion. In such a
1295  // case, we currently use the original block argument type (produced by
1296  // the argument materialization).
1297  legalOutputType = replArgs[0].getType();
1298  }
1299  if (legalOutputType && legalOutputType != origArgType) {
1301  MaterializationKind::Target, computeInsertPoint(argMat),
1302  origArg.getLoc(), argMat, legalOutputType, converter);
1303  mapping.map(argMat, targetMat);
1304  }
1305  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1306  }
1307 
1308  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
1309 
1310  // Erase the old block. (It is just unlinked for now and will be erased during
1311  // cleanup.)
1312  rewriter.eraseBlock(block);
1313 
1314  return newBlock;
1315 }
1316 
1317 //===----------------------------------------------------------------------===//
1318 // Materializations
1319 //===----------------------------------------------------------------------===//
1320 
1321 /// Build an unresolved materialization operation given an output type and set
1322 /// of input operands.
1324  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1325  ValueRange inputs, Type outputType, const TypeConverter *converter) {
1326  // Avoid materializing an unnecessary cast.
1327  if (inputs.size() == 1 && inputs.front().getType() == outputType)
1328  return inputs.front();
1329 
1330  // Create an unresolved materialization. We use a new OpBuilder to avoid
1331  // tracking the materialization like we do for other operations.
1332  OpBuilder builder(outputType.getContext());
1333  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1334  auto convertOp =
1335  builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1336  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1337  return convertOp.getResult(0);
1338 }
1339 
1340 //===----------------------------------------------------------------------===//
1341 // Rewriter Notification Hooks
1342 
1344  Operation *op, OpBuilder::InsertPoint previous) {
1345  LLVM_DEBUG({
1346  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1347  << ")\n";
1348  });
1349  assert(!wasOpReplaced(op->getParentOp()) &&
1350  "attempting to insert into a block within a replaced/erased op");
1351 
1352  if (!previous.isSet()) {
1353  // This is a newly created op.
1354  appendRewrite<CreateOperationRewrite>(op);
1355  return;
1356  }
1357  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1358  ? nullptr
1359  : &*previous.getPoint();
1360  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1361 }
1362 
1364  ValueRange newValues) {
1365  assert(newValues.size() == op->getNumResults());
1366  assert(!ignoredOps.contains(op) && "operation was already replaced");
1367 
1368  // Create mappings for each of the new result values.
1369  for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
1370  if (!newValue) {
1371  // This result was dropped and no replacement value was provided.
1372  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1373  if (unresolvedMaterializations.contains(castOp)) {
1374  // Do not create another materializations if we are erasing a
1375  // materialization.
1376  continue;
1377  }
1378  }
1379 
1380  // Materialize a replacement value "out of thin air".
1381  newValue = buildUnresolvedMaterialization(
1382  MaterializationKind::Source, computeInsertPoint(result),
1383  result.getLoc(), /*inputs=*/ValueRange(),
1384  /*outputType=*/result.getType(), currentTypeConverter);
1385  }
1386 
1387  // Remap, and check for any result type changes.
1388  mapping.map(result, newValue);
1389  }
1390 
1391  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1392 
1393  // Mark this operation and all nested ops as replaced.
1394  op->walk([&](Operation *op) { replacedOps.insert(op); });
1395 }
1396 
1398  appendRewrite<EraseBlockRewrite>(block);
1399 }
1400 
1402  Block *block, Region *previous, Region::iterator previousIt) {
1403  assert(!wasOpReplaced(block->getParentOp()) &&
1404  "attempting to insert into a region within a replaced/erased op");
1405  LLVM_DEBUG(
1406  {
1407  Operation *parent = block->getParentOp();
1408  if (parent) {
1409  logger.startLine() << "** Insert Block into : '" << parent->getName()
1410  << "'(" << parent << ")\n";
1411  } else {
1412  logger.startLine()
1413  << "** Insert Block into detached Region (nullptr parent op)'";
1414  }
1415  });
1416 
1417  if (!previous) {
1418  // This is a newly created block.
1419  appendRewrite<CreateBlockRewrite>(block);
1420  return;
1421  }
1422  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1423  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1424 }
1425 
1427  Block *block, Block *srcBlock, Block::iterator before) {
1428  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1429 }
1430 
1432  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1433  LLVM_DEBUG({
1435  reasonCallback(diag);
1436  logger.startLine() << "** Failure : " << diag.str() << "\n";
1437  if (config.notifyCallback)
1439  });
1440 }
1441 
1442 //===----------------------------------------------------------------------===//
1443 // ConversionPatternRewriter
1444 //===----------------------------------------------------------------------===//
1445 
1446 ConversionPatternRewriter::ConversionPatternRewriter(
1447  MLIRContext *ctx, const ConversionConfig &config)
1448  : PatternRewriter(ctx),
1449  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1450  setListener(impl.get());
1451 }
1452 
1454 
1456  assert(op && newOp && "expected non-null op");
1457  replaceOp(op, newOp->getResults());
1458 }
1459 
1461  assert(op->getNumResults() == newValues.size() &&
1462  "incorrect # of replacement values");
1463  LLVM_DEBUG({
1464  impl->logger.startLine()
1465  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1466  });
1467  impl->notifyOpReplaced(op, newValues);
1468 }
1469 
1471  LLVM_DEBUG({
1472  impl->logger.startLine()
1473  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1474  });
1475  SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1476  impl->notifyOpReplaced(op, nullRepls);
1477 }
1478 
1480  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1481  "attempting to erase a block within a replaced/erased op");
1482 
1483  // Mark all ops for erasure.
1484  for (Operation &op : *block)
1485  eraseOp(&op);
1486 
1487  // Unlink the block from its parent region. The block is kept in the rewrite
1488  // object and will be actually destroyed when rewrites are applied. This
1489  // allows us to keep the operations in the block live and undo the removal by
1490  // re-inserting the block.
1491  impl->notifyBlockIsBeingErased(block);
1492  block->getParent()->getBlocks().remove(block);
1493 }
1494 
1496  Block *block, TypeConverter::SignatureConversion &conversion,
1497  const TypeConverter *converter) {
1498  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1499  "attempting to apply a signature conversion to a block within a "
1500  "replaced/erased op");
1501  return impl->applySignatureConversion(*this, block, converter, conversion);
1502 }
1503 
1505  Region *region, const TypeConverter &converter,
1506  TypeConverter::SignatureConversion *entryConversion) {
1507  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1508  "attempting to apply a signature conversion to a block within a "
1509  "replaced/erased op");
1510  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1511 }
1512 
1514  Value to) {
1515  LLVM_DEBUG({
1516  Operation *parentOp = from.getOwner()->getParentOp();
1517  impl->logger.startLine() << "** Replace Argument : '" << from
1518  << "'(in region of '" << parentOp->getName()
1519  << "'(" << from.getOwner()->getParentOp() << ")\n";
1520  });
1521  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
1522  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1523 }
1524 
1526  SmallVector<Value> remappedValues;
1527  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1528  remappedValues)))
1529  return nullptr;
1530  return remappedValues.front();
1531 }
1532 
1533 LogicalResult
1535  SmallVectorImpl<Value> &results) {
1536  if (keys.empty())
1537  return success();
1538  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1539  results);
1540 }
1541 
1543  Block::iterator before,
1544  ValueRange argValues) {
1545 #ifndef NDEBUG
1546  assert(argValues.size() == source->getNumArguments() &&
1547  "incorrect # of argument replacement values");
1548  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1549  "attempting to inline a block from a replaced/erased op");
1550  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1551  "attempting to inline a block into a replaced/erased op");
1552  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1553  // The source block will be deleted, so it should not have any users (i.e.,
1554  // there should be no predecessors).
1555  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1556  "expected 'source' to have no predecessors");
1557 #endif // NDEBUG
1558 
1559  // If a listener is attached to the dialect conversion, ops cannot be moved
1560  // to the destination block in bulk ("fast path"). This is because at the time
1561  // the notifications are sent, it is unknown which ops were moved. Instead,
1562  // ops should be moved one-by-one ("slow path"), so that a separate
1563  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1564  // a bit more efficient, so we try to do that when possible.
1565  bool fastPath = !impl->config.listener;
1566 
1567  if (fastPath)
1568  impl->notifyBlockBeingInlined(dest, source, before);
1569 
1570  // Replace all uses of block arguments.
1571  for (auto it : llvm::zip(source->getArguments(), argValues))
1572  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1573 
1574  if (fastPath) {
1575  // Move all ops at once.
1576  dest->getOperations().splice(before, source->getOperations());
1577  } else {
1578  // Move op by op.
1579  while (!source->empty())
1580  moveOpBefore(&source->front(), dest, before);
1581  }
1582 
1583  // Erase the source block.
1584  eraseBlock(source);
1585 }
1586 
1588  assert(!impl->wasOpReplaced(op) &&
1589  "attempting to modify a replaced/erased op");
1590 #ifndef NDEBUG
1591  impl->pendingRootUpdates.insert(op);
1592 #endif
1593  impl->appendRewrite<ModifyOperationRewrite>(op);
1594 }
1595 
1597  assert(!impl->wasOpReplaced(op) &&
1598  "attempting to modify a replaced/erased op");
1600  // There is nothing to do here, we only need to track the operation at the
1601  // start of the update.
1602 #ifndef NDEBUG
1603  assert(impl->pendingRootUpdates.erase(op) &&
1604  "operation did not have a pending in-place update");
1605 #endif
1606 }
1607 
1609 #ifndef NDEBUG
1610  assert(impl->pendingRootUpdates.erase(op) &&
1611  "operation did not have a pending in-place update");
1612 #endif
1613  // Erase the last update for this operation.
1614  auto it = llvm::find_if(
1615  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1616  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1617  return modifyRewrite && modifyRewrite->getOperation() == op;
1618  });
1619  assert(it != impl->rewrites.rend() && "no root update started on op");
1620  (*it)->rollback();
1621  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1622  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1623 }
1624 
1626  return *impl;
1627 }
1628 
1629 //===----------------------------------------------------------------------===//
1630 // ConversionPattern
1631 //===----------------------------------------------------------------------===//
1632 
1633 LogicalResult
1635  PatternRewriter &rewriter) const {
1636  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1637  auto &rewriterImpl = dialectRewriter.getImpl();
1638 
1639  // Track the current conversion pattern type converter in the rewriter.
1640  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1641  getTypeConverter());
1642 
1643  // Remap the operands of the operation.
1644  SmallVector<Value, 4> operands;
1645  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1646  op->getOperands(), operands))) {
1647  return failure();
1648  }
1649  return matchAndRewrite(op, operands, dialectRewriter);
1650 }
1651 
1652 //===----------------------------------------------------------------------===//
1653 // OperationLegalizer
1654 //===----------------------------------------------------------------------===//
1655 
1656 namespace {
1657 /// A set of rewrite patterns that can be used to legalize a given operation.
1658 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1659 
1660 /// This class defines a recursive operation legalizer.
1661 class OperationLegalizer {
1662 public:
1663  using LegalizationAction = ConversionTarget::LegalizationAction;
1664 
1665  OperationLegalizer(const ConversionTarget &targetInfo,
1666  const FrozenRewritePatternSet &patterns,
1667  const ConversionConfig &config);
1668 
1669  /// Returns true if the given operation is known to be illegal on the target.
1670  bool isIllegal(Operation *op) const;
1671 
1672  /// Attempt to legalize the given operation. Returns success if the operation
1673  /// was legalized, failure otherwise.
1674  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1675 
1676  /// Returns the conversion target in use by the legalizer.
1677  const ConversionTarget &getTarget() { return target; }
1678 
1679 private:
1680  /// Attempt to legalize the given operation by folding it.
1681  LogicalResult legalizeWithFold(Operation *op,
1682  ConversionPatternRewriter &rewriter);
1683 
1684  /// Attempt to legalize the given operation by applying a pattern. Returns
1685  /// success if the operation was legalized, failure otherwise.
1686  LogicalResult legalizeWithPattern(Operation *op,
1687  ConversionPatternRewriter &rewriter);
1688 
1689  /// Return true if the given pattern may be applied to the given operation,
1690  /// false otherwise.
1691  bool canApplyPattern(Operation *op, const Pattern &pattern,
1692  ConversionPatternRewriter &rewriter);
1693 
1694  /// Legalize the resultant IR after successfully applying the given pattern.
1695  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1696  ConversionPatternRewriter &rewriter,
1697  RewriterState &curState);
1698 
1699  /// Legalizes the actions registered during the execution of a pattern.
1700  LogicalResult
1701  legalizePatternBlockRewrites(Operation *op,
1702  ConversionPatternRewriter &rewriter,
1704  RewriterState &state, RewriterState &newState);
1705  LogicalResult legalizePatternCreatedOperations(
1707  RewriterState &state, RewriterState &newState);
1708  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1710  RewriterState &state,
1711  RewriterState &newState);
1712 
1713  //===--------------------------------------------------------------------===//
1714  // Cost Model
1715  //===--------------------------------------------------------------------===//
1716 
1717  /// Build an optimistic legalization graph given the provided patterns. This
1718  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1719  /// patterns for operations that are not directly legal, but may be
1720  /// transitively legal for the current target given the provided patterns.
1721  void buildLegalizationGraph(
1722  LegalizationPatterns &anyOpLegalizerPatterns,
1724 
1725  /// Compute the benefit of each node within the computed legalization graph.
1726  /// This orders the patterns within 'legalizerPatterns' based upon two
1727  /// criteria:
1728  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1729  /// represent the more direct mapping to the target.
1730  /// 2) When comparing patterns with the same legalization depth, prefer the
1731  /// pattern with the highest PatternBenefit. This allows for users to
1732  /// prefer specific legalizations over others.
1733  void computeLegalizationGraphBenefit(
1734  LegalizationPatterns &anyOpLegalizerPatterns,
1736 
1737  /// Compute the legalization depth when legalizing an operation of the given
1738  /// type.
1739  unsigned computeOpLegalizationDepth(
1740  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1742 
1743  /// Apply the conversion cost model to the given set of patterns, and return
1744  /// the smallest legalization depth of any of the patterns. See
1745  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1746  unsigned applyCostModelToPatterns(
1747  LegalizationPatterns &patterns,
1748  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1750 
1751  /// The current set of patterns that have been applied.
1752  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1753 
1754  /// The legalization information provided by the target.
1755  const ConversionTarget &target;
1756 
1757  /// The pattern applicator to use for conversions.
1758  PatternApplicator applicator;
1759 
1760  /// Dialect conversion configuration.
1761  const ConversionConfig &config;
1762 };
1763 } // namespace
1764 
1765 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1766  const FrozenRewritePatternSet &patterns,
1767  const ConversionConfig &config)
1768  : target(targetInfo), applicator(patterns), config(config) {
1769  // The set of patterns that can be applied to illegal operations to transform
1770  // them into legal ones.
1772  LegalizationPatterns anyOpLegalizerPatterns;
1773 
1774  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1775  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1776 }
1777 
1778 bool OperationLegalizer::isIllegal(Operation *op) const {
1779  return target.isIllegal(op);
1780 }
1781 
1782 LogicalResult
1783 OperationLegalizer::legalize(Operation *op,
1784  ConversionPatternRewriter &rewriter) {
1785 #ifndef NDEBUG
1786  const char *logLineComment =
1787  "//===-------------------------------------------===//\n";
1788 
1789  auto &logger = rewriter.getImpl().logger;
1790 #endif
1791  LLVM_DEBUG({
1792  logger.getOStream() << "\n";
1793  logger.startLine() << logLineComment;
1794  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1795  << op << ") {\n";
1796  logger.indent();
1797 
1798  // If the operation has no regions, just print it here.
1799  if (op->getNumRegions() == 0) {
1800  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1801  logger.getOStream() << "\n\n";
1802  }
1803  });
1804 
1805  // Check if this operation is legal on the target.
1806  if (auto legalityInfo = target.isLegal(op)) {
1807  LLVM_DEBUG({
1808  logSuccess(
1809  logger, "operation marked legal by the target{0}",
1810  legalityInfo->isRecursivelyLegal
1811  ? "; NOTE: operation is recursively legal; skipping internals"
1812  : "");
1813  logger.startLine() << logLineComment;
1814  });
1815 
1816  // If this operation is recursively legal, mark its children as ignored so
1817  // that we don't consider them for legalization.
1818  if (legalityInfo->isRecursivelyLegal) {
1819  op->walk([&](Operation *nested) {
1820  if (op != nested)
1821  rewriter.getImpl().ignoredOps.insert(nested);
1822  });
1823  }
1824 
1825  return success();
1826  }
1827 
1828  // Check to see if the operation is ignored and doesn't need to be converted.
1829  if (rewriter.getImpl().isOpIgnored(op)) {
1830  LLVM_DEBUG({
1831  logSuccess(logger, "operation marked 'ignored' during conversion");
1832  logger.startLine() << logLineComment;
1833  });
1834  return success();
1835  }
1836 
1837  // If the operation isn't legal, try to fold it in-place.
1838  // TODO: Should we always try to do this, even if the op is
1839  // already legal?
1840  if (succeeded(legalizeWithFold(op, rewriter))) {
1841  LLVM_DEBUG({
1842  logSuccess(logger, "operation was folded");
1843  logger.startLine() << logLineComment;
1844  });
1845  return success();
1846  }
1847 
1848  // Otherwise, we need to apply a legalization pattern to this operation.
1849  if (succeeded(legalizeWithPattern(op, rewriter))) {
1850  LLVM_DEBUG({
1851  logSuccess(logger, "");
1852  logger.startLine() << logLineComment;
1853  });
1854  return success();
1855  }
1856 
1857  LLVM_DEBUG({
1858  logFailure(logger, "no matched legalization pattern");
1859  logger.startLine() << logLineComment;
1860  });
1861  return failure();
1862 }
1863 
1864 LogicalResult
1865 OperationLegalizer::legalizeWithFold(Operation *op,
1866  ConversionPatternRewriter &rewriter) {
1867  auto &rewriterImpl = rewriter.getImpl();
1868  RewriterState curState = rewriterImpl.getCurrentState();
1869 
1870  LLVM_DEBUG({
1871  rewriterImpl.logger.startLine() << "* Fold {\n";
1872  rewriterImpl.logger.indent();
1873  });
1874 
1875  // Try to fold the operation.
1876  SmallVector<Value, 2> replacementValues;
1877  rewriter.setInsertionPoint(op);
1878  if (failed(rewriter.tryFold(op, replacementValues))) {
1879  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1880  return failure();
1881  }
1882  // An empty list of replacement values indicates that the fold was in-place.
1883  // As the operation changed, a new legalization needs to be attempted.
1884  if (replacementValues.empty())
1885  return legalize(op, rewriter);
1886 
1887  // Insert a replacement for 'op' with the folded replacement values.
1888  rewriter.replaceOp(op, replacementValues);
1889 
1890  // Recursively legalize any new constant operations.
1891  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
1892  i != e; ++i) {
1893  auto *createOp =
1894  dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
1895  if (!createOp)
1896  continue;
1897  if (failed(legalize(createOp->getOperation(), rewriter))) {
1898  LLVM_DEBUG(logFailure(rewriterImpl.logger,
1899  "failed to legalize generated constant '{0}'",
1900  createOp->getOperation()->getName()));
1901  rewriterImpl.resetState(curState);
1902  return failure();
1903  }
1904  }
1905 
1906  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1907  return success();
1908 }
1909 
1910 LogicalResult
1911 OperationLegalizer::legalizeWithPattern(Operation *op,
1912  ConversionPatternRewriter &rewriter) {
1913  auto &rewriterImpl = rewriter.getImpl();
1914 
1915  // Functor that returns if the given pattern may be applied.
1916  auto canApply = [&](const Pattern &pattern) {
1917  bool canApply = canApplyPattern(op, pattern, rewriter);
1918  if (canApply && config.listener)
1919  config.listener->notifyPatternBegin(pattern, op);
1920  return canApply;
1921  };
1922 
1923  // Functor that cleans up the rewriter state after a pattern failed to match.
1924  RewriterState curState = rewriterImpl.getCurrentState();
1925  auto onFailure = [&](const Pattern &pattern) {
1926  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
1927  LLVM_DEBUG({
1928  logFailure(rewriterImpl.logger, "pattern failed to match");
1929  if (rewriterImpl.config.notifyCallback) {
1931  diag << "Failed to apply pattern \"" << pattern.getDebugName()
1932  << "\" on op:\n"
1933  << *op;
1934  rewriterImpl.config.notifyCallback(diag);
1935  }
1936  });
1937  if (config.listener)
1938  config.listener->notifyPatternEnd(pattern, failure());
1939  rewriterImpl.resetState(curState);
1940  appliedPatterns.erase(&pattern);
1941  };
1942 
1943  // Functor that performs additional legalization when a pattern is
1944  // successfully applied.
1945  auto onSuccess = [&](const Pattern &pattern) {
1946  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
1947  auto result = legalizePatternResult(op, pattern, rewriter, curState);
1948  appliedPatterns.erase(&pattern);
1949  if (failed(result))
1950  rewriterImpl.resetState(curState);
1951  if (config.listener)
1952  config.listener->notifyPatternEnd(pattern, result);
1953  return result;
1954  };
1955 
1956  // Try to match and rewrite a pattern on this operation.
1957  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1958  onSuccess);
1959 }
1960 
1961 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1962  ConversionPatternRewriter &rewriter) {
1963  LLVM_DEBUG({
1964  auto &os = rewriter.getImpl().logger;
1965  os.getOStream() << "\n";
1966  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1967  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
1968  os.getOStream() << ")' {\n";
1969  os.indent();
1970  });
1971 
1972  // Ensure that we don't cycle by not allowing the same pattern to be
1973  // applied twice in the same recursion stack if it is not known to be safe.
1974  if (!pattern.hasBoundedRewriteRecursion() &&
1975  !appliedPatterns.insert(&pattern).second) {
1976  LLVM_DEBUG(
1977  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
1978  return false;
1979  }
1980  return true;
1981 }
1982 
1983 LogicalResult
1984 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
1985  ConversionPatternRewriter &rewriter,
1986  RewriterState &curState) {
1987  auto &impl = rewriter.getImpl();
1988 
1989 #ifndef NDEBUG
1990  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
1991  // Check that the root was either replaced or updated in place.
1992  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
1993  auto replacedRoot = [&] {
1994  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
1995  };
1996  auto updatedRootInPlace = [&] {
1997  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
1998  };
1999  assert((replacedRoot() || updatedRootInPlace()) &&
2000  "expected pattern to replace the root operation");
2001 #endif // NDEBUG
2002 
2003  // Legalize each of the actions registered during application.
2004  RewriterState newState = impl.getCurrentState();
2005  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2006  newState)) ||
2007  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2008  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2009  newState))) {
2010  return failure();
2011  }
2012 
2013  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2014  return success();
2015 }
2016 
2017 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2018  Operation *op, ConversionPatternRewriter &rewriter,
2019  ConversionPatternRewriterImpl &impl, RewriterState &state,
2020  RewriterState &newState) {
2021  SmallPtrSet<Operation *, 16> operationsToIgnore;
2022 
2023  // If the pattern moved or created any blocks, make sure the types of block
2024  // arguments get legalized.
2025  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2026  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2027  if (!rewrite)
2028  continue;
2029  Block *block = rewrite->getBlock();
2030  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2031  ReplaceBlockArgRewrite>(rewrite))
2032  continue;
2033  // Only check blocks outside of the current operation.
2034  Operation *parentOp = block->getParentOp();
2035  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2036  continue;
2037 
2038  // If the region of the block has a type converter, try to convert the block
2039  // directly.
2040  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2041  std::optional<TypeConverter::SignatureConversion> conversion =
2042  converter->convertBlockSignature(block);
2043  if (!conversion) {
2044  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2045  "block"));
2046  return failure();
2047  }
2048  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2049  continue;
2050  }
2051 
2052  // Otherwise, check that this operation isn't one generated by this pattern.
2053  // This is because we will attempt to legalize the parent operation, and
2054  // blocks in regions created by this pattern will already be legalized later
2055  // on. If we haven't built the set yet, build it now.
2056  if (operationsToIgnore.empty()) {
2057  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2058  ++i) {
2059  auto *createOp =
2060  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2061  if (!createOp)
2062  continue;
2063  operationsToIgnore.insert(createOp->getOperation());
2064  }
2065  }
2066 
2067  // If this operation should be considered for re-legalization, try it.
2068  if (operationsToIgnore.insert(parentOp).second &&
2069  failed(legalize(parentOp, rewriter))) {
2070  LLVM_DEBUG(logFailure(impl.logger,
2071  "operation '{0}'({1}) became illegal after rewrite",
2072  parentOp->getName(), parentOp));
2073  return failure();
2074  }
2075  }
2076  return success();
2077 }
2078 
2079 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2081  RewriterState &state, RewriterState &newState) {
2082  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2083  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2084  if (!createOp)
2085  continue;
2086  Operation *op = createOp->getOperation();
2087  if (failed(legalize(op, rewriter))) {
2088  LLVM_DEBUG(logFailure(impl.logger,
2089  "failed to legalize generated operation '{0}'({1})",
2090  op->getName(), op));
2091  return failure();
2092  }
2093  }
2094  return success();
2095 }
2096 
2097 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2099  RewriterState &state, RewriterState &newState) {
2100  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2101  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2102  if (!rewrite)
2103  continue;
2104  Operation *op = rewrite->getOperation();
2105  if (failed(legalize(op, rewriter))) {
2106  LLVM_DEBUG(logFailure(
2107  impl.logger, "failed to legalize operation updated in-place '{0}'",
2108  op->getName()));
2109  return failure();
2110  }
2111  }
2112  return success();
2113 }
2114 
2115 //===----------------------------------------------------------------------===//
2116 // Cost Model
2117 
/// Builds the optimistic legalization graph. For every root operation that is
/// not always legal, `legalizerPatterns` receives the patterns whose generated
/// operations are all legal, dynamically legal, or themselves legalizable by
/// patterns already accepted. This is computed as a worklist fixed point.
/// Patterns without a specific root kind are collected in
/// `anyOpLegalizerPatterns`. If any such pattern exists, no pruning is done
/// and every rooted pattern (of a not-always-legal root) is kept.
/// NOTE(review): the declarations of `parentOps` (after the first comment
/// below) and `invalidPatterns` (after the second) are missing from this
/// listing. From their uses, both are maps keyed by OperationName whose values
/// support insert and iteration (and erase, for `invalidPatterns`); confirm
/// the exact container types against the upstream source.
2118 void OperationLegalizer::buildLegalizationGraph(
2119  LegalizationPatterns &anyOpLegalizerPatterns,
2120  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2121  // A mapping between an operation and a set of operations that can be used to
2122  // generate it.
2124  // A mapping between an operation and any currently invalid patterns it has.
2126  // A worklist of patterns to consider for legality.
2127  SetVector<const Pattern *> patternWorklist;
2128 
2129  // Build the mapping from operations to the parent ops that may generate them.
2130  applicator.walkAllPatterns([&](const Pattern &pattern) {
2131  std::optional<OperationName> root = pattern.getRootKind();
2132 
2133  // If the pattern has no specific root, we can't analyze the relationship
2134  // between the root op and generated operations. Given that, add all such
2135  // patterns to the legalization set.
2136  if (!root) {
2137  anyOpLegalizerPatterns.push_back(&pattern);
2138  return;
2139  }
2140 
2141  // Skip operations that are always known to be legal.
2142  if (target.getOpAction(*root) == LegalizationAction::Legal)
2143  return;
2144 
2145  // Add this pattern to the invalid set for the root op and record this root
2146  // as a parent for any generated operations.
2147  invalidPatterns[*root].insert(&pattern);
2148  for (auto op : pattern.getGeneratedOps())
2149  parentOps[op].insert(*root);
2150 
2151  // Add this pattern to the worklist.
2152  patternWorklist.insert(&pattern);
2153  });
2154 
2155  // If there are any patterns that don't have a specific root kind, we can't
2156  // make direct assumptions about what operations will never be legalized.
2157  // Note: Technically we could, but it would require an analysis that may
2158  // recurse into itself. It would be better to perform this kind of filtering
2159  // at a higher level than here anyways.
2160  if (!anyOpLegalizerPatterns.empty()) {
2161  for (const Pattern *pattern : patternWorklist)
2162  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2163  return;
2164  }
2165 
  // Fixed-point iteration: a pattern is accepted once none of its generated
  // operations is still invalid. An operation is invalid when it has no
  // accepted pattern and its action is unknown or Illegal. Accepting a pattern
  // can in turn validate patterns of the operations that generate its root.
2166  while (!patternWorklist.empty()) {
2167  auto *pattern = patternWorklist.pop_back_val();
2168 
2169  // Check to see if any of the generated operations are invalid.
2170  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2171  std::optional<LegalizationAction> action = target.getOpAction(op);
2172  return !legalizerPatterns.count(op) &&
2173  (!action || action == LegalizationAction::Illegal);
2174  }))
2175  continue;
2176 
2177  // Otherwise, if all of the generated operation are valid, this op is now
2178  // legal so add all of the child patterns to the worklist.
  // (Concretely: record the pattern as valid for its root and drop it from
  // the root's invalid set.)
2179  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2180  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2181 
2182  // Add any invalid patterns of the parent operations to see if they have now
2183  // become legal.
2184  for (auto op : parentOps[*pattern->getRootKind()])
2185  patternWorklist.set_union(invalidPatterns[op]);
2186  }
2187 }
2188 
2189 void OperationLegalizer::computeLegalizationGraphBenefit(
2190  LegalizationPatterns &anyOpLegalizerPatterns,
2191  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2192  // The smallest pattern depth, when legalizing an operation.
2193  DenseMap<OperationName, unsigned> minOpPatternDepth;
2194 
2195  // For each operation that is transitively legal, compute a cost for it.
2196  for (auto &opIt : legalizerPatterns)
2197  if (!minOpPatternDepth.count(opIt.first))
2198  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2199  legalizerPatterns);
2200 
2201  // Apply the cost model to the patterns that can match any operation. Those
2202  // with a specific operation type are already resolved when computing the op
2203  // legalization depth.
2204  if (!anyOpLegalizerPatterns.empty())
2205  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2206  legalizerPatterns);
2207 
2208  // Apply a cost model to the pattern applicator. We order patterns first by
2209  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2210  // decreasing benefit.
2211  applicator.applyCostModel([&](const Pattern &pattern) {
2212  ArrayRef<const Pattern *> orderedPatternList;
2213  if (std::optional<OperationName> rootName = pattern.getRootKind())
2214  orderedPatternList = legalizerPatterns[*rootName];
2215  else
2216  orderedPatternList = anyOpLegalizerPatterns;
2217 
2218  // If the pattern is not found, then it was removed and cannot be matched.
2219  auto *it = llvm::find(orderedPatternList, &pattern);
2220  if (it == orderedPatternList.end())
2222 
2223  // Patterns found earlier in the list have higher benefit.
2224  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2225  });
2226 }
2227 
2228 unsigned OperationLegalizer::computeOpLegalizationDepth(
2229  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2230  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2231  // Check for existing depth.
2232  auto depthIt = minOpPatternDepth.find(op);
2233  if (depthIt != minOpPatternDepth.end())
2234  return depthIt->second;
2235 
2236  // If a mapping for this operation does not exist, then this operation
2237  // is always legal. Return 0 as the depth for a directly legal operation.
2238  auto opPatternsIt = legalizerPatterns.find(op);
2239  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2240  return 0u;
2241 
2242  // Record this initial depth in case we encounter this op again when
2243  // recursively computing the depth.
2244  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2245 
2246  // Apply the cost model to the operation patterns, and update the minimum
2247  // depth.
2248  unsigned minDepth = applyCostModelToPatterns(
2249  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2250  minOpPatternDepth[op] = minDepth;
2251  return minDepth;
2252 }
2253 
2254 unsigned OperationLegalizer::applyCostModelToPatterns(
2255  LegalizationPatterns &patterns,
2256  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2257  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2258  unsigned minDepth = std::numeric_limits<unsigned>::max();
2259 
2260  // Compute the depth for each pattern within the set.
2262  patternsByDepth.reserve(patterns.size());
2263  for (const Pattern *pattern : patterns) {
2264  unsigned depth = 1;
2265  for (auto generatedOp : pattern->getGeneratedOps()) {
2266  unsigned generatedOpDepth = computeOpLegalizationDepth(
2267  generatedOp, minOpPatternDepth, legalizerPatterns);
2268  depth = std::max(depth, generatedOpDepth + 1);
2269  }
2270  patternsByDepth.emplace_back(pattern, depth);
2271 
2272  // Update the minimum depth of the pattern list.
2273  minDepth = std::min(minDepth, depth);
2274  }
2275 
2276  // If the operation only has one legalization pattern, there is no need to
2277  // sort them.
2278  if (patternsByDepth.size() == 1)
2279  return minDepth;
2280 
2281  // Sort the patterns by those likely to be the most beneficial.
2282  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2283  [](const std::pair<const Pattern *, unsigned> &lhs,
2284  const std::pair<const Pattern *, unsigned> &rhs) {
2285  // First sort by the smaller pattern legalization
2286  // depth.
2287  if (lhs.second != rhs.second)
2288  return lhs.second < rhs.second;
2289 
2290  // Then sort by the larger pattern benefit.
2291  auto lhsBenefit = lhs.first->getBenefit();
2292  auto rhsBenefit = rhs.first->getBenefit();
2293  return lhsBenefit > rhsBenefit;
2294  });
2295 
2296  // Update the legalization pattern to use the new sorted list.
2297  patterns.clear();
2298  for (auto &patternIt : patternsByDepth)
2299  patterns.push_back(patternIt.first);
2300  return minDepth;
2301 }
2302 
2303 //===----------------------------------------------------------------------===//
2304 // OperationConverter
2305 //===----------------------------------------------------------------------===//
2306 namespace {
2307 enum OpConversionMode {
2308  /// In this mode, the conversion will ignore failed conversions to allow
2309  /// illegal operations to co-exist in the IR.
2310  Partial,
2311 
2312  /// In this mode, all operations must be legal for the given target for the
2313  /// conversion to succeed.
2314  Full,
2315 
2316  /// In this mode, operations are analyzed for legality. No actual rewrites are
2317  /// applied to the operations on success.
2318  Analysis,
2319 };
2320 } // namespace
2321 
2322 namespace mlir {
2323 // This class converts operations to a given conversion target via a set of
2324 // rewrite patterns. The conversion behaves differently depending on the
2325 // conversion mode.
2327  explicit OperationConverter(const ConversionTarget &target,
2328  const FrozenRewritePatternSet &patterns,
2329  const ConversionConfig &config,
2330  OpConversionMode mode)
2331  : config(config), opLegalizer(target, patterns, this->config),
2332  mode(mode) {}
2333 
2334  /// Converts the given operations to the conversion target.
2335  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2336 
2337 private:
2338  /// Converts an operation with the given rewriter.
2339  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2340 
2341  /// This method is called after the conversion process to legalize any
2342  /// remaining artifacts and complete the conversion.
2343  void finalize(ConversionPatternRewriter &rewriter);
2344 
2345  /// Dialect conversion configuration.
2346  ConversionConfig config;
2347 
2348  /// The legalizer to use when converting operations.
2349  OperationLegalizer opLegalizer;
2350 
2351  /// The conversion mode to use when legalizing operations.
2352  OpConversionMode mode;
2353 };
2354 } // namespace mlir
2355 
2356 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2357  Operation *op) {
2358  // Legalize the given operation.
2359  if (failed(opLegalizer.legalize(op, rewriter))) {
2360  // Handle the case of a failed conversion for each of the different modes.
2361  // Full conversions expect all operations to be converted.
2362  if (mode == OpConversionMode::Full)
2363  return op->emitError()
2364  << "failed to legalize operation '" << op->getName() << "'";
2365  // Partial conversions allow conversions to fail iff the operation was not
2366  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2367  // set, non-legalizable ops are added to that set.
2368  if (mode == OpConversionMode::Partial) {
2369  if (opLegalizer.isIllegal(op))
2370  return op->emitError()
2371  << "failed to legalize operation '" << op->getName()
2372  << "' that was explicitly marked illegal";
2373  if (config.unlegalizedOps)
2374  config.unlegalizedOps->insert(op);
2375  }
2376  } else if (mode == OpConversionMode::Analysis) {
2377  // Analysis conversions don't fail if any operations fail to legalize,
2378  // they are only interested in the operations that were successfully
2379  // legalized.
2380  if (config.legalizableOps)
2381  config.legalizableOps->insert(op);
2382  }
2383  return success();
2384 }
2385 
2386 static LogicalResult
2388  UnresolvedMaterializationRewrite *rewrite) {
2389  UnrealizedConversionCastOp op = rewrite->getOperation();
2390  assert(!op.use_empty() &&
2391  "expected that dead materializations have already been DCE'd");
2392  Operation::operand_range inputOperands = op.getOperands();
2393  Type outputType = op.getResultTypes()[0];
2394 
2395  // Try to materialize the conversion.
2396  if (const TypeConverter *converter = rewrite->getConverter()) {
2397  rewriter.setInsertionPoint(op);
2398  Value newMaterialization;
2399  switch (rewrite->getMaterializationKind()) {
2401  // Try to materialize an argument conversion.
2402  newMaterialization = converter->materializeArgumentConversion(
2403  rewriter, op->getLoc(), outputType, inputOperands);
2404  if (newMaterialization)
2405  break;
2406  // If an argument materialization failed, fallback to trying a target
2407  // materialization.
2408  [[fallthrough]];
2409  case MaterializationKind::Target:
2410  newMaterialization = converter->materializeTargetConversion(
2411  rewriter, op->getLoc(), outputType, inputOperands);
2412  break;
2413  case MaterializationKind::Source:
2414  newMaterialization = converter->materializeSourceConversion(
2415  rewriter, op->getLoc(), outputType, inputOperands);
2416  break;
2417  }
2418  if (newMaterialization) {
2419  assert(newMaterialization.getType() == outputType &&
2420  "materialization callback produced value of incorrect type");
2421  rewriter.replaceOp(op, newMaterialization);
2422  return success();
2423  }
2424  }
2425 
2427  << "failed to legalize unresolved materialization "
2428  "from ("
2429  << inputOperands.getTypes() << ") to " << outputType
2430  << " that remained live after conversion";
2431  diag.attachNote(op->getUsers().begin()->getLoc())
2432  << "see existing live user here: " << *op->getUsers().begin();
2433  return failure();
2434 }
2435 
2437  if (ops.empty())
2438  return success();
2439  const ConversionTarget &target = opLegalizer.getTarget();
2440 
2441  // Compute the set of operations and blocks to convert.
2442  SmallVector<Operation *> toConvert;
2443  for (auto *op : ops) {
2445  [&](Operation *op) {
2446  toConvert.push_back(op);
2447  // Don't check this operation's children for conversion if the
2448  // operation is recursively legal.
2449  auto legalityInfo = target.isLegal(op);
2450  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2451  return WalkResult::skip();
2452  return WalkResult::advance();
2453  });
2454  }
2455 
2456  // Convert each operation and discard rewrites on failure.
2457  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2458  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2459 
2460  for (auto *op : toConvert)
2461  if (failed(convert(rewriter, op)))
2462  return rewriterImpl.undoRewrites(), failure();
2463 
2464  // Now that all of the operations have been converted, finalize the conversion
2465  // process to ensure any lingering conversion artifacts are cleaned up and
2466  // legalized.
2467  finalize(rewriter);
2468 
2469  // After a successful conversion, apply rewrites.
2470  rewriterImpl.applyRewrites();
2471 
2472  // Gather all unresolved materializations.
2475  &materializations = rewriterImpl.unresolvedMaterializations;
2476  for (auto it : materializations) {
2477  if (rewriterImpl.eraseRewriter.wasErased(it.first))
2478  continue;
2479  allCastOps.push_back(it.first);
2480  }
2481 
2482  // Reconcile all UnrealizedConversionCastOps that were inserted by the
2483  // dialect conversion frameworks. (Not the one that were inserted by
2484  // patterns.)
2485  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2486  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2487 
2488  // Try to legalize all unresolved materializations.
2489  if (config.buildMaterializations) {
2490  IRRewriter rewriter(rewriterImpl.context, config.listener);
2491  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2492  auto it = materializations.find(castOp);
2493  assert(it != materializations.end() && "inconsistent state");
2494  if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2495  return failure();
2496  }
2497  }
2498 
2499  return success();
2500 }
2501 
2502 /// Finds a user of the given value, or of any other value that the given value
2503 /// replaced, that was not replaced in the conversion process.
2505  Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2506  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2507  SmallVector<Value> worklist = {initialValue};
2508  while (!worklist.empty()) {
2509  Value value = worklist.pop_back_val();
2510 
2511  // Walk the users of this value to see if there are any live users that
2512  // weren't replaced during conversion.
2513  auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2514  return rewriterImpl.isOpIgnored(user);
2515  });
2516  if (liveUserIt != value.user_end())
2517  return *liveUserIt;
2518  auto mapIt = inverseMapping.find(value);
2519  if (mapIt != inverseMapping.end())
2520  worklist.append(mapIt->second);
2521  }
2522  return nullptr;
2523 }
2524 
2525 /// Helper function that returns the replaced values and the type converter if
2526 /// the given rewrite object is an "operation replacement" or a "block type
2527 /// conversion" (which corresponds to a "block replacement"). Otherwise, return
2528 /// an empty ValueRange and a null type converter pointer.
2529 static std::pair<ValueRange, const TypeConverter *>
2531  if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
2532  return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
2533  if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
2534  return {blockRewrite->getOrigBlock()->getArguments(),
2535  blockRewrite->getConverter()};
2536  return {};
2537 }
2538 
2539 void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2540  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2541  DenseMap<Value, SmallVector<Value>> inverseMapping =
2542  rewriterImpl.mapping.getInverse();
2543 
2544  // Process requested value replacements.
2545  for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
2546  ValueRange replacedValues;
2547  const TypeConverter *converter;
2548  std::tie(replacedValues, converter) =
2549  getReplacedValues(rewriterImpl.rewrites[i].get());
2550  for (Value originalValue : replacedValues) {
2551  // If the type of this value changed and the value is still live, we need
2552  // to materialize a conversion.
2553  if (rewriterImpl.mapping.lookupOrNull(originalValue,
2554  originalValue.getType()))
2555  continue;
2556  Operation *liveUser =
2557  findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
2558  if (!liveUser)
2559  continue;
2560 
2561  // Legalize this value replacement.
2562  Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
2563  assert(newValue && "replacement value not found");
2564  Value castValue = rewriterImpl.buildUnresolvedMaterialization(
2565  MaterializationKind::Source, computeInsertPoint(newValue),
2566  originalValue.getLoc(),
2567  /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
2568  converter);
2569  rewriterImpl.mapping.map(originalValue, castValue);
2570  inverseMapping[castValue].push_back(originalValue);
2571  llvm::erase(inverseMapping[newValue], originalValue);
2572  }
2573  }
2574 }
2575 
2576 //===----------------------------------------------------------------------===//
2577 // Reconcile Unrealized Casts
2578 //===----------------------------------------------------------------------===//
2579 
2582  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
  // Simplifies the unrealized_conversion_cast ops in `castOps`:
  //  * a cast without uses is erased (DCE);
  //  * a cast whose result types equal the input types of some cast in its
  //    producer chain (starting with the cast itself, and following producers
  //    whose outputs are exactly the next cast's inputs) has its uses replaced
  //    by that cast's inputs and is erased.
  // Producer casts of every erased op are (re)added to the worklist, since
  // they may have become dead or foldable. If `remainingCastOps` is non-null,
  // it receives every op of `castOps` that was not erased, in input order.
2583  SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
2584  castOps.end());
2585  // This set is maintained only if `remainingCastOps` is provided.
2586  DenseSet<Operation *> erasedOps;
2587 
2588  // Helper function that adds all operands to the worklist that are an
2589  // unrealized_conversion_cast op result.
2590  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2591  for (Value v : castOp.getInputs())
2592  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2593  worklist.insert(inputCastOp);
2594  };
2595 
2596  // Helper function that returns the unrealized_conversion_cast op that
2597  // defines all inputs of the given op (in the same order). Return "nullptr"
2598  // if there is no such op.
2599  auto getInputCast =
2600  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2601  if (castOp.getInputs().empty())
2602  return {};
2603  auto inputCastOp =
2604  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2605  if (!inputCastOp)
2606  return {};
2607  if (inputCastOp.getOutputs() != castOp.getInputs())
2608  return {};
2609  return inputCastOp;
2610  };
2611 
2612  // Process ops in the worklist bottom-to-top.
2613  while (!worklist.empty()) {
2614  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2615  if (castOp->use_empty()) {
2616  // DCE: If the op has no users, erase it. Add the operands to the
2617  // worklist to find additional DCE opportunities.
2618  enqueueOperands(castOp);
2619  if (remainingCastOps)
2620  erasedOps.insert(castOp.getOperation());
2621  castOp->erase();
2622  continue;
2623  }
2624 
2625  // Traverse the chain of input cast ops to see if an op with the same
2626  // input types can be found.
2627  UnrealizedConversionCastOp nextCast = castOp;
2628  while (nextCast) {
2629  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2630  // Found a cast where the input types match the output types of the
2631  // matched op. We can directly use those inputs and the matched op can
2632  // be removed.
2633  enqueueOperands(castOp);
2634  castOp.replaceAllUsesWith(nextCast.getInputs());
2635  if (remainingCastOps)
2636  erasedOps.insert(castOp.getOperation());
2637  castOp->erase();
2638  break;
2639  }
2640  nextCast = getInputCast(nextCast);
2641  }
2642  }
2643 
  // Report the surviving input casts in their original order.
2644  if (remainingCastOps)
2645  for (UnrealizedConversionCastOp op : castOps)
2646  if (!erasedOps.contains(op.getOperation()))
2647  remainingCastOps->push_back(op);
2648 }
2649 
2650 //===----------------------------------------------------------------------===//
2651 // Type Conversion
2652 //===----------------------------------------------------------------------===//
2653 
2655  ArrayRef<Type> types) {
  // Maps original input `origInputNo` onto `types.size()` new inputs that are
  // appended at the current end of `argTypes`.
2656  assert(!types.empty() && "expected valid types");
2657  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2658  addInputs(types);
2659 }
2660 
2662  assert(!types.empty() &&
2663  "1->0 type remappings don't need to be added explicitly");
  // Appends the new input types without recording any mapping from an
  // original input.
2664  argTypes.append(types.begin(), types.end());
2665 }
2666 
2667 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2668  unsigned newInputNo,
2669  unsigned newInputCount) {
2670  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2671  assert(newInputCount != 0 && "expected valid input count");
2672  remappedInputs[origInputNo] =
2673  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2674 }
2675 
2677  Value replacementValue) {
  // Maps original input `origInputNo` directly to `replacementValue`. The
  // recorded mapping has size 0, i.e. no new input is added for it.
2678  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2679  remappedInputs[origInputNo] =
2680  InputMapping{origInputNo, /*size=*/0, replacementValue};
2681 }
2682 
2684  SmallVectorImpl<Type> &results) const {
  // Converts `t`, appending the resulting type(s) to `results`. Results are
  // memoized per input type: one converted type (or nullptr for a failed
  // conversion) in `cachedDirectConversions`, any other result count in
  // `cachedMultiConversions`.
2685  {
  // Fast path: consult both caches under a shared (reader) lock.
2686  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2687  std::defer_lock);
  // NOTE(review): source line 2688 is missing from this listing (note the gap
  // in line numbers), so the lock() below may be guarded by a condition that
  // is not shown here.
2689  cacheReadLock.lock();
2690  auto existingIt = cachedDirectConversions.find(t);
2691  if (existingIt != cachedDirectConversions.end()) {
  // A cached nullptr records an earlier failed conversion.
2692  if (existingIt->second)
2693  results.push_back(existingIt->second);
2694  return success(existingIt->second != nullptr);
2695  }
2696  auto multiIt = cachedMultiConversions.find(t);
2697  if (multiIt != cachedMultiConversions.end()) {
2698  results.append(multiIt->second.begin(), multiIt->second.end());
2699  return success();
2700  }
2701  }
2702  // Walk the added converters in reverse order to apply the most recently
2703  // registered first.
2704  size_t currentCount = results.size();
2705 
2706  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2707  std::defer_lock);
2708 
2709  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
  // A converter returning std::nullopt does not handle `t`; try the next one.
2710  if (std::optional<LogicalResult> result = converter(t, results)) {
  // NOTE(review): source line 2711 is also missing here; as with line 2688,
  // the lock() below may be conditional.
2712  cacheWriteLock.lock();
2713  if (!succeeded(*result)) {
2714  cachedDirectConversions.try_emplace(t, nullptr);
2715  return failure();
2716  }
  // Only the types appended by this call (past `currentCount`) are cached.
2717  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2718  if (newTypes.size() == 1)
2719  cachedDirectConversions.try_emplace(t, newTypes.front());
2720  else
2721  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2722  return success();
2723  }
2724  }
  // No converter handled `t`. This outcome is not written to either cache.
2725  return failure();
2726 }
2727 
  // Single-type overload: returns the converted type, or a null Type if the
  // conversion fails or does not produce exactly one type.
2729  // Use the multi-type result version to convert the type.
2730  SmallVector<Type, 1> results;
2731  if (failed(convertType(t, results)))
2732  return nullptr;
2733 
2734  // Check to ensure that only one type was produced.
2735  return results.size() == 1 ? results.front() : nullptr;
2736 }
2737 
2738 LogicalResult
2740  SmallVectorImpl<Type> &results) const {
  // Converts each type in order, appending to `results`. Stops at the first
  // failure, in which case `results` may already hold some converted types.
2741  for (Type type : types)
2742  if (failed(convertType(type, results)))
2743  return failure();
2744  return success();
2745 }
2746 
2747 bool TypeConverter::isLegal(Type type) const {
2748  return convertType(type) == type;
2749 }
  // An operation is legal when all its operand types and result types are.
2751  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2752 }
2753 
2754 bool TypeConverter::isLegal(Region *region) const {
2755  return llvm::all_of(*region, [this](Block &block) {
2756  return isLegal(block.getArgumentTypes());
2757  });
2758 }
2759 
2760 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2761  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2762 }
2763 
2764 LogicalResult
2766  SignatureConversion &result) const {
  // Converts the type of original input `inputNo` and records the resulting
  // new input(s) in `result`. A conversion to zero types drops the input: it
  // succeeds without recording anything for `inputNo`.
2767  // Try to convert the given input type.
2768  SmallVector<Type, 1> convertedTypes;
2769  if (failed(convertType(type, convertedTypes)))
2770  return failure();
2771 
2772  // If this argument is being dropped, there is nothing left to do.
2773  if (convertedTypes.empty())
2774  return success();
2775 
2776  // Otherwise, add the new inputs.
2777  result.addInputs(inputNo, convertedTypes);
2778  return success();
2779 }
2780 LogicalResult
2782  SignatureConversion &result,
2783  unsigned origInputOffset) const {
  // Converts each type in `types` as original input `origInputOffset + i`.
  // Stops at the first failure.
2784  for (unsigned i = 0, e = types.size(); i != e; ++i)
2785  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2786  return failure();
2787  return success();
2788 }
2789 
2790 Value TypeConverter::materializeConversion(
2791  ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
2792  Location loc, Type resultType, ValueRange inputs) const {
2793  for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
2794  if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
2795  return *result;
2796  return nullptr;
2797 }
2798 
2799 std::optional<TypeConverter::SignatureConversion>
  // Builds a SignatureConversion covering all arguments of `block`, or returns
  // std::nullopt if any argument type fails to convert.
2801  SignatureConversion conversion(block->getNumArguments());
2802  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2803  return std::nullopt;
2804  return conversion;
2805 }
2806 
2807 //===----------------------------------------------------------------------===//
2808 // Type attribute conversion
2809 //===----------------------------------------------------------------------===//
  // AttributeConversionResult helpers. The constructors below pair a payload
  // with one of three tags (resultTag, naTag, abortTag); the predicates
  // compare the stored tag, `impl.getInt()`.
2812  return AttributeConversionResult(attr, resultTag);
2813 }
2814 
2817  return AttributeConversionResult(nullptr, naTag);
2818 }
2819 
2822  return AttributeConversionResult(nullptr, abortTag);
2823 }
2824 
2826  return impl.getInt() == resultTag;
2827 }
2828 
2830  return impl.getInt() == naTag;
2831 }
2832 
2834  return impl.getInt() == abortTag;
2835 }
2836 
  // The payload is only meaningful for the resultTag case; asserted here.
2838  assert(hasResult() && "Cannot get result from N/A or abort");
2839  return impl.getPointer();
2840 }
2841 
2842 std::optional<Attribute>
  // Tries the type-attribute conversion callbacks, most recently registered
  // first. Returns the first produced result; an "abort" result stops the
  // search with std::nullopt, and "N/A" results move on to the next callback.
2844  for (const TypeAttributeConversionCallbackFn &fn :
2845  llvm::reverse(typeAttributeConversions)) {
2846  AttributeConversionResult res = fn(type, attr);
2847  if (res.hasResult())
2848  return res.getResult();
2849  if (res.isAbort())
2850  return std::nullopt;
2851  }
2852  return std::nullopt;
2853 }
2854 
2855 //===----------------------------------------------------------------------===//
2856 // FunctionOpInterfaceSignatureConversion
2857 //===----------------------------------------------------------------------===//
2858 
2859 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
2860  const TypeConverter &typeConverter,
2861  ConversionPatternRewriter &rewriter) {
2862  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
2863  if (!type)
2864  return failure();
2865 
2866  // Convert the original function types.
2867  TypeConverter::SignatureConversion result(type.getNumInputs());
2868  SmallVector<Type, 1> newResults;
2869  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
2870  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
2871  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
2872  typeConverter, &result)))
2873  return failure();
2874 
2875  // Update the function signature in-place.
2876  auto newType = FunctionType::get(rewriter.getContext(),
2877  result.getConvertedTypes(), newResults);
2878 
2879  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
2880 
2881  return success();
2882 }
2883 
2884 /// Create a default conversion pattern that rewrites the type signature of a
2885 /// FunctionOpInterface op. This only supports ops which use FunctionType to
2886 /// represent their type.
2887 namespace {
// Pattern bound to one op name (`functionLikeOpName`); delegates to
// convertFuncOpTypes.
2888 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
2889  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
2890  MLIRContext *ctx,
2891  const TypeConverter &converter)
2892  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
2893 
2894  LogicalResult
2895  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
2896  ConversionPatternRewriter &rewriter) const override {
2897  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
2898  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
2899  }
2900 };
2901 
// Pattern matching any op implementing FunctionOpInterface; delegates to
// convertFuncOpTypes. NOTE(review): source line 2904 (presumably the
// inherited-constructor declaration) is missing from this listing.
2902 struct AnyFunctionOpInterfaceSignatureConversion
2903  : public OpInterfaceConversionPattern<FunctionOpInterface> {
2905 
2906  LogicalResult
2907  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
2908  ConversionPatternRewriter &rewriter) const override {
2909  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
2910  }
2911 };
2912 } // namespace
2913 
2914 FailureOr<Operation *>
2916  const TypeConverter &converter,
2917  ConversionPatternRewriter &rewriter) {
  // Creates (but does not insert as a replacement) a clone of `op` with the
  // same name and attributes, the given `operands`, and converted result
  // types. Fails if `op` is already legal or its result types don't convert.
  // NOTE(review): only operands, result types and attributes are copied here;
  // regions and successors of `op` are not carried over.
2918  assert(op && "Invalid op");
2919  Location loc = op->getLoc();
2920  if (converter.isLegal(op))
2921  return rewriter.notifyMatchFailure(loc, "op already legal");
2922 
2923  OperationState newOp(loc, op->getName());
2924  newOp.addOperands(operands);
2925 
2926  SmallVector<Type> newResultTypes;
2927  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
2928  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
2929 
2930  newOp.addTypes(newResultTypes);
2931  newOp.addAttributes(op->getAttrs());
2932  return rewriter.create(newOp);
2933 }
2934 
2936  StringRef functionLikeOpName, RewritePatternSet &patterns,
2937  const TypeConverter &converter) {
  // Adds a signature-conversion pattern for ops named `functionLikeOpName`.
2938  patterns.add<FunctionOpInterfaceSignatureConversion>(
2939  functionLikeOpName, patterns.getContext(), converter);
2940 }
2941 
2943  RewritePatternSet &patterns, const TypeConverter &converter) {
  // Adds a signature-conversion pattern for any FunctionOpInterface op.
2944  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
2945  converter, patterns.getContext());
2946 }
2947 
2948 //===----------------------------------------------------------------------===//
2949 // ConversionTarget
2950 //===----------------------------------------------------------------------===//
2951 
2953  LegalizationAction action) {
  // Records `action` for operation `op`, overwriting any previous action.
2954  legalOperations[op].action = action;
2955 }
2956 
2958  LegalizationAction action) {
  // Records `action` for every dialect namespace in `dialectNames`.
2959  for (StringRef dialect : dialectNames)
2960  legalDialects[dialect] = action;
2961 }
2962 
2964  -> std::optional<LegalizationAction> {
  // Returns the action resolved by getOpInfo, or std::nullopt if none applies.
2965  std::optional<LegalizationInfo> info = getOpInfo(op);
2966  return info ? info->action : std::optional<LegalizationAction>();
2967 }
2968 
2970  -> std::optional<LegalOpDetails> {
  // Returns legality details when `op` is legal on this target, and
  // std::nullopt otherwise (including when no legality info applies to it).
2971  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
2972  if (!info)
2973  return std::nullopt;
2974 
2975  // Returns true if this operation instance is known to be legal.
2976  auto isOpLegal = [&] {
2977  // Handle dynamic legality with the provided legality function; if it
  // returns std::nullopt, fall through to the static check below.
2978  if (info->action == LegalizationAction::Dynamic) {
2979  std::optional<bool> result = info->legalityFn(op);
2980  if (result)
2981  return *result;
2982  }
2983 
2984  // Otherwise, the operation is only legal if it was marked 'Legal'.
2985  return info->action == LegalizationAction::Legal;
2986  };
2987  if (!isOpLegal())
2988  return std::nullopt;
2989 
2990  // This operation is legal, compute any additional legality information.
2991  LegalOpDetails legalityDetails;
  // Recursive legality: use the per-op callback when one is registered (an
  // undecided callback counts as legal); otherwise it holds unconditionally.
2992  if (info->isRecursivelyLegal) {
2993  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
2994  if (legalityFnIt != opRecursiveLegalityFns.end()) {
2995  legalityDetails.isRecursivelyLegal =
2996  legalityFnIt->second(op).value_or(true);
2997  } else {
2998  legalityDetails.isRecursivelyLegal = true;
2999  }
3000  }
3001  return legalityDetails;
3002 }
3003 
  // Returns true only when `op` is known to be illegal: either marked Illegal,
  // or dynamic with a legality function that explicitly answers false. Ops
  // without legality info, or whose dynamic callback is undecided, are not
  // reported as illegal.
3005  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3006  if (!info)
3007  return false;
3008 
3009  if (info->action == LegalizationAction::Dynamic) {
3010  std::optional<bool> result = info->legalityFn(op);
3011  if (!result)
3012  return false;
3013 
3014  return !(*result);
3015  }
3016 
3017  return info->action == LegalizationAction::Illegal;
3018 }
3019 
3023  if (!oldCallback)
3024  return newCallback;
3025 
3026  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3027  Operation *op) -> std::optional<bool> {
3028  if (std::optional<bool> result = newCl(op))
3029  return *result;
3030 
3031  return oldCl(op);
3032  };
3033  return chain;
3034 }
3035 
3036 void ConversionTarget::setLegalityCallback(
3037  OperationName name, const DynamicLegalityCallbackFn &callback) {
3038  assert(callback && "expected valid legality callback");
3039  auto *infoIt = legalOperations.find(name);
3040  assert(infoIt != legalOperations.end() &&
3041  infoIt->second.action == LegalizationAction::Dynamic &&
3042  "expected operation to already be marked as dynamically legal");
3043  infoIt->second.legalityFn =
3044  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3045 }
3046 
3048  OperationName name, const DynamicLegalityCallbackFn &callback) {
  // Marks `name` (already Legal or Dynamic) as recursively legal. A non-null
  // `callback` is chained in front of any existing recursive-legality
  // callback; a null one removes the callback, making the op unconditionally
  // recursively legal.
3049  auto *infoIt = legalOperations.find(name);
3050  assert(infoIt != legalOperations.end() &&
3051  infoIt->second.action != LegalizationAction::Illegal &&
3052  "expected operation to already be marked as legal");
3053  infoIt->second.isRecursivelyLegal = true;
3054  if (callback)
3055  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3056  std::move(opRecursiveLegalityFns[name]), callback);
3057  else
3058  opRecursiveLegalityFns.erase(name);
3059 }
3060 
3061 void ConversionTarget::setLegalityCallback(
3062  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3063  assert(callback && "expected valid legality callback");
3064  for (StringRef dialect : dialects)
3065  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3066  std::move(dialectLegalityFns[dialect]), callback);
3067 }
3068 
3069 void ConversionTarget::setLegalityCallback(
3070  const DynamicLegalityCallbackFn &callback) {
3071  assert(callback && "expected valid legality callback");
3072  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3073 }
3074 
3075 auto ConversionTarget::getOpInfo(OperationName op) const
3076  -> std::optional<LegalizationInfo> {
3077  // Check for info for this specific operation.
3078  const auto *it = legalOperations.find(op);
3079  if (it != legalOperations.end())
3080  return it->second;
3081  // Check for info for the parent dialect.
3082  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3083  if (dialectIt != legalDialects.end()) {
3084  DynamicLegalityCallbackFn callback;
3085  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3086  if (dialectFn != dialectLegalityFns.end())
3087  callback = dialectFn->second;
3088  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3089  callback};
3090  }
3091  // Otherwise, check if we mark unknown operations as dynamic.
3092  if (unknownLegalityFn)
3093  return LegalizationInfo{LegalizationAction::Dynamic,
3094  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3095  return std::nullopt;
3096 }
3097 
3098 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3099 //===----------------------------------------------------------------------===//
3100 // PDL Configuration
3101 //===----------------------------------------------------------------------===//
3102 
  // Makes this config's type converter the current one for the duration of a
  // PDL rewrite.
3104  auto &rewriterImpl =
3105  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3106  rewriterImpl.currentTypeConverter = getTypeConverter();
3107 }
3108 
  // Clears the current type converter once the PDL rewrite is finished.
3110  auto &rewriterImpl =
3111  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3112  rewriterImpl.currentTypeConverter = nullptr;
3113 }
3114 
3115 /// Remap the given value using the rewriter and the type converter in the
3116 /// provided config.
3117 static FailureOr<SmallVector<Value>>
3119  SmallVector<Value> mappedValues;
3120  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3121  return failure();
3122  return std::move(mappedValues);
3123 }
3124 
  // Registers the PDL rewrite functions "convertValue", "convertValues",
  // "convertType" and "convertTypes" on `patterns`. The type functions use the
  // rewriter's current type converter and return their input unchanged when
  // no converter is set.
3126  patterns.getPDLPatterns().registerRewriteFunction(
3127  "convertValue",
3128  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3129  auto results = pdllConvertValues(
3130  static_cast<ConversionPatternRewriter &>(rewriter), value);
3131  if (failed(results))
3132  return failure();
3133  return results->front();
3134  });
3135  patterns.getPDLPatterns().registerRewriteFunction(
3136  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3137  return pdllConvertValues(
3138  static_cast<ConversionPatternRewriter &>(rewriter), values);
3139  });
3140  patterns.getPDLPatterns().registerRewriteFunction(
3141  "convertType",
3142  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3143  auto &rewriterImpl =
3144  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3145  if (const TypeConverter *converter =
3146  rewriterImpl.currentTypeConverter) {
  // Fails unless the conversion yields exactly one type.
3147  if (Type newType = converter->convertType(type))
3148  return newType;
3149  return failure();
3150  }
3151  return type;
3152  });
3153  patterns.getPDLPatterns().registerRewriteFunction(
3154  "convertTypes",
3155  [](PatternRewriter &rewriter,
3156  TypeRange types) -> FailureOr<SmallVector<Type>> {
3157  auto &rewriterImpl =
3158  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3159  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3160  if (!converter)
3161  return SmallVector<Type>(types);
3162 
3163  SmallVector<Type> remappedTypes;
3164  if (failed(converter->convertTypes(types, remappedTypes)))
3165  return failure();
3166  return std::move(remappedTypes);
3167  });
3168 }
3169 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3170 
3171 //===----------------------------------------------------------------------===//
3172 // Op Conversion Entry Points
3173 //===----------------------------------------------------------------------===//
3174 
3175 //===----------------------------------------------------------------------===//
3176 // Partial Conversion
3177 
3179  ArrayRef<Operation *> ops, const ConversionTarget &target,
3180  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
  // Runs an OperationConverter over `ops` in OpConversionMode::Partial.
3181  OperationConverter opConverter(target, patterns, config,
3182  OpConversionMode::Partial);
3183  return opConverter.convertOperations(ops);
3184 }
3185 LogicalResult
3187  const FrozenRewritePatternSet &patterns,
3188  ConversionConfig config) {
  // Single-op overload: forwards `op` as a one-element ArrayRef.
3189  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3190 }
3190 }
3191 
3192 //===----------------------------------------------------------------------===//
3193 // Full Conversion
3194 
3196  const ConversionTarget &target,
3197  const FrozenRewritePatternSet &patterns,
3198  ConversionConfig config) {
  // Runs an OperationConverter over `ops` in OpConversionMode::Full.
3199  OperationConverter opConverter(target, patterns, config,
3200  OpConversionMode::Full);
3201  return opConverter.convertOperations(ops);
3202 }
3204  const ConversionTarget &target,
3205  const FrozenRewritePatternSet &patterns,
3206  ConversionConfig config) {
  // Single-op overload: forwards `op` as a one-element ArrayRef.
3207  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3208 }
3209 
3210 //===----------------------------------------------------------------------===//
3211 // Analysis Conversion
3212 
3213 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3214 /// op is a top-level module op (which is expected to be isolated from above),
3215 /// return that op.
  // Precondition: `ops` is non-empty (ops.front() is dereferenced below).
3217  // Check if there is a top-level operation within `ops`. If so, return that
3218  // op.
3219  for (Operation *op : ops) {
3220  if (!op->getParentOp()) {
3221 #ifndef NDEBUG
3222  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3223  "expected top-level op to be isolated from above");
3224  for (Operation *other : ops)
3225  assert(op->isAncestor(other) &&
3226  "expected ops to have a common ancestor");
3227 #endif // NDEBUG
3228  return op;
3229  }
3230  }
3231 
3232  // No top-level op. Find a common ancestor.
3233  Operation *commonAncestor =
3234  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
  // Widen the candidate until it properly contains every remaining op.
  // NOTE(review): source line 3238 (the right-hand side of the assignment in
  // this loop, presumably the next IsolatedFromAbove parent) is missing from
  // this listing.
3235  for (Operation *op : ops.drop_front()) {
3236  while (!commonAncestor->isProperAncestor(op)) {
3237  commonAncestor =
3239  assert(commonAncestor &&
3240  "expected to find a common isolated from above ancestor");
3241  }
3242  }
3243 
3244  return commonAncestor;
3245 }
3246 
3249  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
  // Runs the conversion in OpConversionMode::Analysis on a clone of the
  // closest common IsolatedFromAbove ancestor of `ops`, so the original IR is
  // left unchanged. If `config.legalizableOps` is set, it is rewritten to
  // refer to the original ops. The clone is erased before returning the
  // conversion status.
3250 #ifndef NDEBUG
3251  if (config.legalizableOps)
3252  assert(config.legalizableOps->empty() && "expected empty set");
3253 #endif // NDEBUG
3254 
3255  // Clone the closest common ancestor that is isolated from above.
3256  Operation *commonAncestor = findCommonAncestor(ops);
3257  IRMapping mapping;
3258  Operation *clonedAncestor = commonAncestor->clone(mapping);
3259  // Compute inverse IR mapping.
3260  DenseMap<Operation *, Operation *> inverseOperationMap;
3261  for (auto &it : mapping.getOperationMap())
3262  inverseOperationMap[it.second] = it.first;
3263 
3264  // Convert the cloned operations. The original IR will remain unchanged.
3265  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3266  ops, [&](Operation *op) { return mapping.lookup(op); });
3267  OperationConverter opConverter(target, patterns, config,
3268  OpConversionMode::Analysis);
3269  LogicalResult status = opConverter.convertOperations(opsToConvert);
3270 
3271  // Remap `legalizableOps`, so that they point to the original ops and not the
3272  // cloned ops.
  // NOTE(review): operator[] yields (and inserts) nullptr for an op that is
  // not in the clone mapping, e.g. one created during the conversion; confirm
  // such ops cannot reach `legalizableOps`.
3273  if (config.legalizableOps) {
3274  DenseSet<Operation *> originalLegalizableOps;
3275  for (Operation *op : *config.legalizableOps)
3276  originalLegalizableOps.insert(inverseOperationMap[op]);
3277  *config.legalizableOps = std::move(originalLegalizableOps);
3278  }
3279 
3280  // Erase the cloned IR.
3281  clonedAncestor->erase();
3282  return status;
3283 }
3283 }
3284 
3285 LogicalResult
3287  const FrozenRewritePatternSet &patterns,
3288  ConversionConfig config) {
  // Single-op overload: forwards `op` as a one-element ArrayRef.
3289  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3290 }
static std::pair< ValueRange, const TypeConverter * > getReplacedValues(IRRewrite *rewrite)
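Helper function that returns the replaced values and the type converter if the given rewrite object is an "operation replacement" or a "block type conversion" (which corresponds to a "block replacement"). Otherwise, return an empty ValueRange and a null type converter pointer.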
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced, that was not replaced in the conversion process.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used without a dominance violation.
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
bool empty()
Definition: Block.h:146
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:93
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator end()
Definition: Block.h:142
iterator begin()
Definition: Block.h:141
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true is operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:313
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:335
Block::iterator getPoint() const
Definition: Builders.h:348
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:345
Block * getBlock() const
Definition: Builders.h:347
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:328
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:461
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:488
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:503
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:605
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:892
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
MLIRContext * getContext() const
Definition: PatternMatch.h:823
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_iterator user_end() const
Definition: Value.h:227
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition: Argument.h:64
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target/ argument materializations t...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a single value that remaps an existing signature input...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
Value buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueRange inputs, Type outputType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/block that were already erased and skips duplicate op/block erasur...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.