MLIR  19.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Iterators.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::detail;
29 
30 #define DEBUG_TYPE "dialect-conversion"
31 
32 /// A utility function to log a successful result for the given reason.
33 template <typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
35  LLVM_DEBUG({
36  os.unindent();
37  os.startLine() << "} -> SUCCESS";
38  if (!fmt.empty())
39  os.getOStream() << " : "
40  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41  os.getOStream() << "\n";
42  });
43 }
44 
45 /// A utility function to log a failure result for the given reason.
46 template <typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
48  LLVM_DEBUG({
49  os.unindent();
50  os.startLine() << "} -> FAILURE : "
51  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
52  << "\n";
53  });
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // ConversionValueMapping
58 //===----------------------------------------------------------------------===//
59 
60 namespace {
61 /// This class wraps a IRMapping to provide recursive lookup
62 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
63 struct ConversionValueMapping {
64  /// Lookup a mapped value within the map. If a mapping for the provided value
65  /// does not exist then return the provided value. If `desiredType` is
66  /// non-null, returns the most recently mapped value with that type. If an
67  /// operand of that type does not exist, defaults to normal behavior.
68  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
69 
70  /// Lookup a mapped value within the map, or return null if a mapping does not
71  /// exist. If a mapping exists, this follows the same behavior of
72  /// `lookupOrDefault`.
73  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
74 
75  /// Map a value to the one provided.
76  void map(Value oldVal, Value newVal) {
77  LLVM_DEBUG({
78  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
79  assert(it != oldVal && "inserting cyclic mapping");
80  });
81  mapping.map(oldVal, newVal);
82  }
83 
84  /// Try to map a value to the one provided. Returns false if a transitive
85  /// mapping from the new value to the old value already exists, true if the
86  /// map was updated.
87  bool tryMap(Value oldVal, Value newVal);
88 
89  /// Drop the last mapping for the given value.
90  void erase(Value value) { mapping.erase(value); }
91 
92  /// Returns the inverse raw value mapping (without recursive query support).
93  DenseMap<Value, SmallVector<Value>> getInverse() const {
95  for (auto &it : mapping.getValueMap())
96  inverse[it.second].push_back(it.first);
97  return inverse;
98  }
99 
100 private:
101  /// Current value mappings.
102  IRMapping mapping;
103 };
104 } // namespace
105 
106 Value ConversionValueMapping::lookupOrDefault(Value from,
107  Type desiredType) const {
108  // If there was no desired type, simply find the leaf value.
109  if (!desiredType) {
110  // If this value had a valid mapping, unmap that value as well in the case
111  // that it was also replaced.
112  while (auto mappedValue = mapping.lookupOrNull(from))
113  from = mappedValue;
114  return from;
115  }
116 
117  // Otherwise, try to find the deepest value that has the desired type.
118  Value desiredValue;
119  do {
120  if (from.getType() == desiredType)
121  desiredValue = from;
122 
123  Value mappedValue = mapping.lookupOrNull(from);
124  if (!mappedValue)
125  break;
126  from = mappedValue;
127  } while (true);
128 
129  // If the desired value was found use it, otherwise default to the leaf value.
130  return desiredValue ? desiredValue : from;
131 }
132 
133 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
134  Value result = lookupOrDefault(from, desiredType);
135  if (result == from || (desiredType && result.getType() != desiredType))
136  return nullptr;
137  return result;
138 }
139 
140 bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
141  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
142  if (it == oldVal)
143  return false;
144  map(oldVal, newVal);
145  return true;
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // Rewriter and Translation State
150 //===----------------------------------------------------------------------===//
151 namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. A snapshot can
/// be taken before a set of rewrites and used later to undo back to it.
struct RewriterState {
  RewriterState(unsigned rewriteCount, unsigned ignoredOpCount,
                unsigned replacedOpCount)
      : numRewrites(rewriteCount), numIgnoredOperations(ignoredOpCount),
        numReplacedOps(replacedOpCount) {}

  /// Number of rewrites that had been performed at snapshot time.
  unsigned numRewrites;

  /// Number of ignored operations at snapshot time.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure at snapshot time.
  unsigned numReplacedOps;
};
169 
170 //===----------------------------------------------------------------------===//
171 // IR rewrites
172 //===----------------------------------------------------------------------===//
173 
174 /// An IR rewrite that can be committed (upon success) or rolled back (upon
175 /// failure).
176 ///
177 /// The dialect conversion keeps track of IR modifications (requested by the
178 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
179 /// are directly applied to the IR as the rewriter API is used, some are applied
180 /// partially, and some are delayed until the `IRRewrite` objects are committed.
181 class IRRewrite {
182 public:
183  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
184  /// Enum values are ordered, so that they can be used in `classof`: first all
185  /// block rewrites, then all operation rewrites.
186  enum class Kind {
187  // Block rewrites
188  CreateBlock,
189  EraseBlock,
190  InlineBlock,
191  MoveBlock,
192  BlockTypeConversion,
193  ReplaceBlockArg,
194  // Operation rewrites
195  MoveOperation,
196  ModifyOperation,
197  ReplaceOperation,
198  CreateOperation,
199  UnresolvedMaterialization
200  };
201 
202  virtual ~IRRewrite() = default;
203 
204  /// Roll back the rewrite. Operations may be erased during rollback.
205  virtual void rollback() = 0;
206 
207  /// Commit the rewrite. At this point, it is certain that the dialect
208  /// conversion will succeed. All IR modifications, except for operation/block
209  /// erasure, must be performed through the given rewriter.
210  ///
211  /// Instead of erasing operations/blocks, they should merely be unlinked
212  /// commit phase and finally be erased during the cleanup phase. This is
213  /// because internal dialect conversion state (such as `mapping`) may still
214  /// be using them.
215  ///
216  /// Any IR modification that was already performed before the commit phase
217  /// (e.g., insertion of an op) must be communicated to the listener that may
218  /// be attached to the given rewriter.
219  virtual void commit(RewriterBase &rewriter) {}
220 
221  /// Cleanup operations/blocks. Cleanup is called after commit.
222  virtual void cleanup(RewriterBase &rewriter) {}
223 
224  Kind getKind() const { return kind; }
225 
226  static bool classof(const IRRewrite *rewrite) { return true; }
227 
228 protected:
229  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
230  : kind(kind), rewriterImpl(rewriterImpl) {}
231 
232  const ConversionConfig &getConfig() const;
233 
234  const Kind kind;
235  ConversionPatternRewriterImpl &rewriterImpl;
236 };
237 
238 /// A block rewrite.
239 class BlockRewrite : public IRRewrite {
240 public:
241  /// Return the block that this rewrite operates on.
242  Block *getBlock() const { return block; }
243 
244  static bool classof(const IRRewrite *rewrite) {
245  return rewrite->getKind() >= Kind::CreateBlock &&
246  rewrite->getKind() <= Kind::ReplaceBlockArg;
247  }
248 
249 protected:
250  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
251  Block *block)
252  : IRRewrite(kind, rewriterImpl), block(block) {}
253 
254  // The block that this rewrite operates on.
255  Block *block;
256 };
257 
258 /// Creation of a block. Block creations are immediately reflected in the IR.
259 /// There is no extra work to commit the rewrite. During rollback, the newly
260 /// created block is erased.
261 class CreateBlockRewrite : public BlockRewrite {
262 public:
263  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
264  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
265 
266  static bool classof(const IRRewrite *rewrite) {
267  return rewrite->getKind() == Kind::CreateBlock;
268  }
269 
270  void commit(RewriterBase &rewriter) override {
271  // The block was already created and inserted. Just inform the listener.
272  if (auto *listener = rewriter.getListener())
273  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
274  }
275 
276  void rollback() override {
277  // Unlink all of the operations within this block, they will be deleted
278  // separately.
279  auto &blockOps = block->getOperations();
280  while (!blockOps.empty())
281  blockOps.remove(blockOps.begin());
282  block->dropAllUses();
283  if (block->getParent())
284  block->erase();
285  else
286  delete block;
287  }
288 };
289 
290 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
291 /// blocks are immediately unlinked, but only erased during cleanup. This makes
292 /// it easier to rollback a block erasure: the block is simply inserted into its
293 /// original location.
294 class EraseBlockRewrite : public BlockRewrite {
295 public:
296  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
297  Region *region, Block *insertBeforeBlock)
298  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), region(region),
299  insertBeforeBlock(insertBeforeBlock) {}
300 
301  static bool classof(const IRRewrite *rewrite) {
302  return rewrite->getKind() == Kind::EraseBlock;
303  }
304 
305  ~EraseBlockRewrite() override {
306  assert(!block &&
307  "rewrite was neither rolled back nor committed/cleaned up");
308  }
309 
310  void rollback() override {
311  // The block (owned by this rewrite) was not actually erased yet. It was
312  // just unlinked. Put it back into its original position.
313  assert(block && "expected block");
314  auto &blockList = region->getBlocks();
315  Region::iterator before = insertBeforeBlock
316  ? Region::iterator(insertBeforeBlock)
317  : blockList.end();
318  blockList.insert(before, block);
319  block = nullptr;
320  }
321 
322  void commit(RewriterBase &rewriter) override {
323  // Erase the block.
324  assert(block && "expected block");
325  assert(block->empty() && "expected empty block");
326 
327  // Notify the listener that the block is about to be erased.
328  if (auto *listener =
329  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
330  listener->notifyBlockErased(block);
331  }
332 
333  void cleanup(RewriterBase &rewriter) override {
334  // Erase the block.
335  block->dropAllDefinedValueUses();
336  delete block;
337  block = nullptr;
338  }
339 
340 private:
341  // The region in which this block was previously contained.
342  Region *region;
343 
344  // The original successor of this block before it was unlinked. "nullptr" if
345  // this block was the only block in the region.
346  Block *insertBeforeBlock;
347 };
348 
349 /// Inlining of a block. This rewrite is immediately reflected in the IR.
350 /// Note: This rewrite represents only the inlining of the operations. The
351 /// erasure of the inlined block is a separate rewrite.
352 class InlineBlockRewrite : public BlockRewrite {
353 public:
354  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
355  Block *sourceBlock, Block::iterator before)
356  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
357  sourceBlock(sourceBlock),
358  firstInlinedInst(sourceBlock->empty() ? nullptr
359  : &sourceBlock->front()),
360  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
361  // If a listener is attached to the dialect conversion, ops must be moved
362  // one-by-one. When they are moved in bulk, notifications cannot be sent
363  // because the ops that used to be in the source block at the time of the
364  // inlining (before the "commit" phase) are unknown at the time when
365  // notifications are sent (which is during the "commit" phase).
366  assert(!getConfig().listener &&
367  "InlineBlockRewrite not supported if listener is attached");
368  }
369 
370  static bool classof(const IRRewrite *rewrite) {
371  return rewrite->getKind() == Kind::InlineBlock;
372  }
373 
374  void rollback() override {
375  // Put the operations from the destination block (owned by the rewrite)
376  // back into the source block.
377  if (firstInlinedInst) {
378  assert(lastInlinedInst && "expected operation");
379  sourceBlock->getOperations().splice(sourceBlock->begin(),
380  block->getOperations(),
381  Block::iterator(firstInlinedInst),
382  ++Block::iterator(lastInlinedInst));
383  }
384  }
385 
386 private:
387  // The block that originally contained the operations.
388  Block *sourceBlock;
389 
390  // The first inlined operation.
391  Operation *firstInlinedInst;
392 
393  // The last inlined operation.
394  Operation *lastInlinedInst;
395 };
396 
397 /// Moving of a block. This rewrite is immediately reflected in the IR.
398 class MoveBlockRewrite : public BlockRewrite {
399 public:
400  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
401  Region *region, Block *insertBeforeBlock)
402  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
403  insertBeforeBlock(insertBeforeBlock) {}
404 
405  static bool classof(const IRRewrite *rewrite) {
406  return rewrite->getKind() == Kind::MoveBlock;
407  }
408 
409  void commit(RewriterBase &rewriter) override {
410  // The block was already moved. Just inform the listener.
411  if (auto *listener = rewriter.getListener()) {
412  // Note: `previousIt` cannot be passed because this is a delayed
413  // notification and iterators into past IR state cannot be represented.
414  listener->notifyBlockInserted(block, /*previous=*/region,
415  /*previousIt=*/{});
416  }
417  }
418 
419  void rollback() override {
420  // Move the block back to its original position.
421  Region::iterator before =
422  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
423  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
424  }
425 
426 private:
427  // The region in which this block was previously contained.
428  Region *region;
429 
430  // The original successor of this block before it was moved. "nullptr" if
431  // this block was the only block in the region.
432  Block *insertBeforeBlock;
433 };
434 
435 /// This structure contains the information pertaining to an argument that has
436 /// been converted.
437 struct ConvertedArgInfo {
438  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
439  Value castValue = nullptr)
440  : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
441 
442  /// The start index of in the new argument list that contains arguments that
443  /// replace the original.
444  unsigned newArgIdx;
445 
446  /// The number of arguments that replaced the original argument.
447  unsigned newArgSize;
448 
449  /// The cast value that was created to cast from the new arguments to the
450  /// old. This only used if 'newArgSize' > 1.
451  Value castValue;
452 };
453 
454 /// Block type conversion. This rewrite is partially reflected in the IR.
455 class BlockTypeConversionRewrite : public BlockRewrite {
456 public:
457  BlockTypeConversionRewrite(
458  ConversionPatternRewriterImpl &rewriterImpl, Block *block,
459  Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
460  const TypeConverter *converter)
461  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
462  origBlock(origBlock), argInfo(argInfo), converter(converter) {}
463 
464  static bool classof(const IRRewrite *rewrite) {
465  return rewrite->getKind() == Kind::BlockTypeConversion;
466  }
467 
468  /// Materialize any necessary conversions for converted arguments that have
469  /// live users, using the provided `findLiveUser` to search for a user that
470  /// survives the conversion process.
471  LogicalResult
472  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
473 
474  void commit(RewriterBase &rewriter) override;
475 
476  void rollback() override;
477 
478 private:
479  /// The original block that was requested to have its signature converted.
480  Block *origBlock;
481 
482  /// The conversion information for each of the arguments. The information is
483  /// std::nullopt if the argument was dropped during conversion.
485 
486  /// The type converter used to convert the arguments.
487  const TypeConverter *converter;
488 };
489 
490 /// Replacing a block argument. This rewrite is not immediately reflected in the
491 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
492 /// until the rewrite is committed.
493 class ReplaceBlockArgRewrite : public BlockRewrite {
494 public:
495  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
496  Block *block, BlockArgument arg)
497  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
498 
499  static bool classof(const IRRewrite *rewrite) {
500  return rewrite->getKind() == Kind::ReplaceBlockArg;
501  }
502 
503  void commit(RewriterBase &rewriter) override;
504 
505  void rollback() override;
506 
507 private:
508  BlockArgument arg;
509 };
510 
511 /// An operation rewrite.
512 class OperationRewrite : public IRRewrite {
513 public:
514  /// Return the operation that this rewrite operates on.
515  Operation *getOperation() const { return op; }
516 
517  static bool classof(const IRRewrite *rewrite) {
518  return rewrite->getKind() >= Kind::MoveOperation &&
519  rewrite->getKind() <= Kind::UnresolvedMaterialization;
520  }
521 
522 protected:
523  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
524  Operation *op)
525  : IRRewrite(kind, rewriterImpl), op(op) {}
526 
527  // The operation that this rewrite operates on.
528  Operation *op;
529 };
530 
531 /// Moving of an operation. This rewrite is immediately reflected in the IR.
532 class MoveOperationRewrite : public OperationRewrite {
533 public:
534  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
535  Operation *op, Block *block, Operation *insertBeforeOp)
536  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
537  insertBeforeOp(insertBeforeOp) {}
538 
539  static bool classof(const IRRewrite *rewrite) {
540  return rewrite->getKind() == Kind::MoveOperation;
541  }
542 
543  void commit(RewriterBase &rewriter) override {
544  // The operation was already moved. Just inform the listener.
545  if (auto *listener = rewriter.getListener()) {
546  // Note: `previousIt` cannot be passed because this is a delayed
547  // notification and iterators into past IR state cannot be represented.
548  listener->notifyOperationInserted(
549  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
550  /*insertPt=*/{}));
551  }
552  }
553 
554  void rollback() override {
555  // Move the operation back to its original position.
556  Block::iterator before =
557  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
558  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
559  }
560 
561 private:
562  // The block in which this operation was previously contained.
563  Block *block;
564 
565  // The original successor of this operation before it was moved. "nullptr"
566  // if this operation was the only operation in the region.
567  Operation *insertBeforeOp;
568 };
569 
570 /// In-place modification of an op. This rewrite is immediately reflected in
571 /// the IR. The previous state of the operation is stored in this object.
572 class ModifyOperationRewrite : public OperationRewrite {
573 public:
574  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
575  Operation *op)
576  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
577  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
578  operands(op->operand_begin(), op->operand_end()),
579  successors(op->successor_begin(), op->successor_end()) {
580  if (OpaqueProperties prop = op->getPropertiesStorage()) {
581  // Make a copy of the properties.
582  propertiesStorage = operator new(op->getPropertiesStorageSize());
583  OpaqueProperties propCopy(propertiesStorage);
584  name.initOpProperties(propCopy, /*init=*/prop);
585  }
586  }
587 
588  static bool classof(const IRRewrite *rewrite) {
589  return rewrite->getKind() == Kind::ModifyOperation;
590  }
591 
592  ~ModifyOperationRewrite() override {
593  assert(!propertiesStorage &&
594  "rewrite was neither committed nor rolled back");
595  }
596 
597  void commit(RewriterBase &rewriter) override {
598  // Notify the listener that the operation was modified in-place.
599  if (auto *listener =
600  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
601  listener->notifyOperationModified(op);
602 
603  if (propertiesStorage) {
604  OpaqueProperties propCopy(propertiesStorage);
605  // Note: The operation may have been erased in the mean time, so
606  // OperationName must be stored in this object.
607  name.destroyOpProperties(propCopy);
608  operator delete(propertiesStorage);
609  propertiesStorage = nullptr;
610  }
611  }
612 
613  void rollback() override {
614  op->setLoc(loc);
615  op->setAttrs(attrs);
616  op->setOperands(operands);
617  for (const auto &it : llvm::enumerate(successors))
618  op->setSuccessor(it.value(), it.index());
619  if (propertiesStorage) {
620  OpaqueProperties propCopy(propertiesStorage);
621  op->copyProperties(propCopy);
622  name.destroyOpProperties(propCopy);
623  operator delete(propertiesStorage);
624  propertiesStorage = nullptr;
625  }
626  }
627 
628 private:
629  OperationName name;
630  LocationAttr loc;
631  DictionaryAttr attrs;
632  SmallVector<Value, 8> operands;
633  SmallVector<Block *, 2> successors;
634  void *propertiesStorage = nullptr;
635 };
636 
637 /// Replacing an operation. Erasing an operation is treated as a special case
638 /// with "null" replacements. This rewrite is not immediately reflected in the
639 /// IR. An internal IR mapping is updated, but values are not replaced and the
640 /// original op is not erased until the rewrite is committed.
641 class ReplaceOperationRewrite : public OperationRewrite {
642 public:
643  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
644  Operation *op, const TypeConverter *converter,
645  bool changedResults)
646  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
647  converter(converter), changedResults(changedResults) {}
648 
649  static bool classof(const IRRewrite *rewrite) {
650  return rewrite->getKind() == Kind::ReplaceOperation;
651  }
652 
653  void commit(RewriterBase &rewriter) override;
654 
655  void rollback() override;
656 
657  void cleanup(RewriterBase &rewriter) override;
658 
659  const TypeConverter *getConverter() const { return converter; }
660 
661  bool hasChangedResults() const { return changedResults; }
662 
663 private:
664  /// An optional type converter that can be used to materialize conversions
665  /// between the new and old values if necessary.
666  const TypeConverter *converter;
667 
668  /// A boolean flag that indicates whether result types have changed or not.
669  bool changedResults;
670 };
671 
672 class CreateOperationRewrite : public OperationRewrite {
673 public:
674  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
675  Operation *op)
676  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
677 
678  static bool classof(const IRRewrite *rewrite) {
679  return rewrite->getKind() == Kind::CreateOperation;
680  }
681 
682  void commit(RewriterBase &rewriter) override {
683  // The operation was already created and inserted. Just inform the listener.
684  if (auto *listener = rewriter.getListener())
685  listener->notifyOperationInserted(op, /*previous=*/{});
686  }
687 
688  void rollback() override;
689 };
690 
/// The kind of an unresolved materialization.
enum MaterializationKind {
  /// Materializes a conversion from an illegal block argument type to a
  /// legal one.
  Argument = 0,

  /// Materializes a conversion from an illegal type to a legal one.
  Target = 1
};
701 
702 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
703 /// op. Unresolved materializations are erased at the end of the dialect
704 /// conversion.
705 class UnresolvedMaterializationRewrite : public OperationRewrite {
706 public:
707  UnresolvedMaterializationRewrite(
708  ConversionPatternRewriterImpl &rewriterImpl,
709  UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
710  MaterializationKind kind = MaterializationKind::Target)
711  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
712  converterAndKind(converter, kind) {}
713 
714  static bool classof(const IRRewrite *rewrite) {
715  return rewrite->getKind() == Kind::UnresolvedMaterialization;
716  }
717 
718  UnrealizedConversionCastOp getOperation() const {
719  return cast<UnrealizedConversionCastOp>(op);
720  }
721 
722  void rollback() override;
723 
724  void cleanup(RewriterBase &rewriter) override;
725 
726  /// Return the type converter of this materialization (which may be null).
727  const TypeConverter *getConverter() const {
728  return converterAndKind.getPointer();
729  }
730 
731  /// Return the kind of this materialization.
732  MaterializationKind getMaterializationKind() const {
733  return converterAndKind.getInt();
734  }
735 
736 private:
737  /// The corresponding type converter to use when resolving this
738  /// materialization, and the kind of this materialization.
739  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
740  converterAndKind;
741 };
742 } // namespace
743 
744 /// Return "true" if there is an operation rewrite that matches the specified
745 /// rewrite type and operation among the given rewrites.
746 template <typename RewriteTy, typename R>
747 static bool hasRewrite(R &&rewrites, Operation *op) {
748  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
749  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
750  return rewriteTy && rewriteTy->getOperation() == op;
751  });
752 }
753 
754 /// Find the single rewrite object of the specified type and block among the
755 /// given rewrites. In debug mode, asserts that there is mo more than one such
756 /// object. Return "nullptr" if no object was found.
757 template <typename RewriteTy, typename R>
758 static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
759  RewriteTy *result = nullptr;
760  for (auto &rewrite : rewrites) {
761  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
762  if (rewriteTy && rewriteTy->getBlock() == block) {
763 #ifndef NDEBUG
764  assert(!result && "expected single matching rewrite");
765  result = rewriteTy;
766 #else
767  return rewriteTy;
768 #endif // NDEBUG
769  }
770  }
771  return result;
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // ConversionPatternRewriterImpl
776 //===----------------------------------------------------------------------===//
777 namespace mlir {
778 namespace detail {
781  const ConversionConfig &config)
782  : context(ctx), config(config) {}
783 
784  //===--------------------------------------------------------------------===//
785  // State Management
786  //===--------------------------------------------------------------------===//
787 
788  /// Return the current state of the rewriter.
789  RewriterState getCurrentState();
790 
791  /// Apply all requested operation rewrites. This method is invoked when the
792  /// conversion process succeeds.
793  void applyRewrites();
794 
795  /// Reset the state of the rewriter to a previously saved point.
796  void resetState(RewriterState state);
797 
798  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
799  /// failure.
800  template <typename RewriteTy, typename... Args>
801  void appendRewrite(Args &&...args) {
802  rewrites.push_back(
803  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
804  }
805 
806  /// Undo the rewrites (motions, splits) one by one in reverse order until
807  /// "numRewritesToKeep" rewrites remains.
808  void undoRewrites(unsigned numRewritesToKeep = 0);
809 
810  /// Remap the given values to those with potentially different types. Returns
811  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
812  /// is the tag used when describing a value within a diagnostic, e.g.
813  /// "operand".
814  LogicalResult remapValues(StringRef valueDiagTag,
815  std::optional<Location> inputLoc,
816  PatternRewriter &rewriter, ValueRange values,
817  SmallVectorImpl<Value> &remapped);
818 
819  /// Return "true" if the given operation is ignored, and does not need to be
820  /// converted.
821  bool isOpIgnored(Operation *op) const;
822 
823  /// Return "true" if the given operation was replaced or erased.
824  bool wasOpReplaced(Operation *op) const;
825 
826  //===--------------------------------------------------------------------===//
827  // Type Conversion
828  //===--------------------------------------------------------------------===//
829 
830  /// Convert the types of block arguments within the given region.
831  FailureOr<Block *>
832  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
833  const TypeConverter &converter,
834  TypeConverter::SignatureConversion *entryConversion);
835 
836  /// Apply the given signature conversion on the given block. The new block
837  /// containing the updated signature is returned. If no conversions were
838  /// necessary, e.g. if the block has no arguments, `block` is returned.
839  /// `converter` is used to generate any necessary cast operations that
840  /// translate between the origin argument types and those specified in the
841  /// signature conversion.
842  Block *applySignatureConversion(
843  ConversionPatternRewriter &rewriter, Block *block,
844  const TypeConverter *converter,
845  TypeConverter::SignatureConversion &signatureConversion);
846 
847  //===--------------------------------------------------------------------===//
848  // Materializations
849  //===--------------------------------------------------------------------===//
850  /// Build an unresolved materialization operation given an output type and set
851  /// of input operands.
  /// If `inputs` is a single value that already has `outputType`, that value is
  /// returned and no operation is created.
852  Value buildUnresolvedMaterialization(MaterializationKind kind,
853  Block *insertBlock,
854  Block::iterator insertPt, Location loc,
855  ValueRange inputs, Type outputType,
856  const TypeConverter *converter);
857 
  /// Build an unresolved argument materialization that converts `inputs` to
  /// `outputType`, inserted at the beginning of `block`.
858  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
859  ValueRange inputs,
860  Type outputType,
861  const TypeConverter *converter);
862 
  /// Build an unresolved target materialization that converts `input` to
  /// `outputType`. It is inserted right after the op that defines `input`, or
  /// at the beginning of the owning block if `input` is a block argument.
863  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
864  Type outputType,
865  const TypeConverter *converter);
866 
867  //===--------------------------------------------------------------------===//
868  // Rewriter Notification Hooks
869  //===--------------------------------------------------------------------===//
870 
871  /// Notifies that an op was inserted.
  /// An unset `previous` insertion point means the op was newly created.
  /// Otherwise the op was moved, and its previous position is recorded.
872  void notifyOperationInserted(Operation *op,
873  OpBuilder::InsertPoint previous) override;
874 
875  /// Notifies that an op is about to be replaced with the given values.
  /// `newValues` has one entry per result of `op`. A null entry means the
  /// result is dropped without a replacement. `op` and all nested ops are
  /// marked as replaced.
876  void notifyOpReplaced(Operation *op, ValueRange newValues);
877 
878  /// Notifies that a block is about to be erased.
879  void notifyBlockIsBeingErased(Block *block);
880 
881  /// Notifies that a block was inserted.
  /// A null `previous` region means the block was newly created. Otherwise it
  /// was moved from `previous` (before `previousIt`).
882  void notifyBlockInserted(Block *block, Region *previous,
883  Region::iterator previousIt) override;
884 
885  /// Notifies that a block is being inlined into another block.
886  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
887  Block::iterator before);
888 
889  /// Notifies that a pattern match failed for the given reason.
  /// The implementation is wrapped entirely in LLVM_DEBUG, so this only has
  /// an effect in debug builds.
890  void
891  notifyMatchFailure(Location loc,
892  function_ref<void(Diagnostic &)> reasonCallback) override;
893 
894  //===--------------------------------------------------------------------===//
895  // IR Erasure
896  //===--------------------------------------------------------------------===//
897 
898  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
899  /// operation or block is erased multiple times. This rewriter assumes that
900  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
  /// (Presumably this is because a newly allocated op/block could reuse the
  /// address of an erased one and would then be mistaken for already erased.)
  /// NOTE(review): several lines of this class are missing from this listing
  /// (numbering gaps at 901, 903, 911, 920 and 928). `erased` is used below,
  /// but its declaration is not shown.
902  public:
  // The rewriter is its own listener, so the notify*Erased hooks below are
  // what populate `erased`.
904  : RewriterBase(context, /*listener=*/this) {}
905 
906  /// Erase the given op (unless it was already erased).
907  void eraseOp(Operation *op) override {
908  if (erased.contains(op))
909  return;
  // Drop remaining uses first. Presumably other ops that are also scheduled
  // for erasure may still use this op's results.
910  op->dropAllUses();
912  }
913 
914  /// Erase the given block (unless it was already erased).
915  void eraseBlock(Block *block) override {
916  if (erased.contains(block))
917  return;
918  assert(block->empty() && "expected empty block");
919  block->dropAllDefinedValueUses();
921  }
922 
923  void notifyOperationErased(Operation *op) override { erased.insert(op); }
924 
925  void notifyBlockErased(Block *block) override { erased.insert(block); }
926 
927  /// Pointers to all erased operations and blocks.
929  };
930 
931  //===--------------------------------------------------------------------===//
932  // State
933  //===--------------------------------------------------------------------===//
934 
  // NOTE(review): the member declarations that should follow several of the
  // comments below are missing from this listing (numbering gaps at 936, 943,
  // 948, 954, 962, 965 and 971). Elsewhere in this file, `context`,
  // `rewrites`, `ignoredOps`, `replacedOps`, `regionToConverter`, `config` and
  // `pendingRootUpdates` are used.
935  /// MLIR context.
937 
938  /// Mapping between replaced values that differ in type. This happens when
939  /// replacing a value with one of a different type.
940  ConversionValueMapping mapping;
941 
942  /// Ordered list of all recorded IR rewrites (op/block creations, motions,
  /// replacements, modifications, materializations, ...), in the order in
  /// which they happened. They are committed in this order and rolled back in
  /// reverse.
944 
945  /// A set of operations that should no longer be considered for legalization.
946  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
947  /// tracked separately.
949 
950  /// A set of operations that were replaced/erased. Such ops are not erased
951  /// immediately but only when the dialect conversion succeeds. In the
952  /// meantime, they should no longer be considered for legalization and any attempt
953  /// to modify/access them is invalid rewriter API usage.
955 
956  /// The current type converter, or nullptr if no type converter is currently
957  /// active.
  /// It is set for the duration of ConversionPattern::matchAndRewrite.
958  const TypeConverter *currentTypeConverter = nullptr;
959 
960  /// A mapping of regions to type converters that should be used when
961  /// converting the arguments of blocks within that region.
963 
964  /// Dialect conversion configuration.
966 
967 #ifndef NDEBUG
968  /// A set of operations that have pending updates. This tracking isn't
969  /// strictly necessary, and is thus only active during debug builds for extra
970  /// verification.
972 
973  /// A logger used to emit diagnostics during the conversion process.
974  llvm::ScopedPrinter logger{llvm::dbgs()};
975 #endif
976 };
977 } // namespace detail
978 } // namespace mlir
979 
980 const ConversionConfig &IRRewrite::getConfig() const {
981  return rewriterImpl.config;
982 }
983 
984 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
985  // Inform the listener about all IR modifications that have already taken
986  // place: References to the original block have been replaced with the new
987  // block.
988  if (auto *listener =
989  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
990  for (Operation *op : block->getUsers())
991  listener->notifyOperationModified(op);
992 
993  // Process the remapping for each of the original arguments.
994  for (auto [origArg, info] :
995  llvm::zip_equal(origBlock->getArguments(), argInfo)) {
996  // Handle the case of a 1->0 value mapping.
997  if (!info) {
998  if (Value newArg =
999  rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
1000  rewriter.replaceAllUsesWith(origArg, newArg);
1001  continue;
1002  }
1003 
1004  // Otherwise this is a 1->1+ value mapping.
1005  Value castValue = info->castValue;
1006  assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
1007 
1008  // If the argument is still used, replace it with the generated cast.
1009  if (!origArg.use_empty()) {
1010  rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
1011  castValue, origArg.getType()));
1012  }
1013  }
1014 }
1015 
1016 void BlockTypeConversionRewrite::rollback() {
1017  block->replaceAllUsesWith(origBlock);
1018 }
1019 
1020 LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1021  function_ref<Operation *(Value)> findLiveUser) {
  // Handle each original argument that has no mapped value of its own type but
  // still has a live user (as reported by `findLiveUser`): build a source
  // materialization back to the original type and map the argument to it.
  // If no materialization can be built (including when there is no type
  // converter), emit an error and return failure.
1022  // Process the remapping for each of the original arguments.
1023  for (auto it : llvm::enumerate(origBlock->getArguments())) {
1024  BlockArgument origArg = it.value();
1025  // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
  // NOTE(review): the builder is constructed even for arguments that are
  // skipped by the `continue`s below; it could be created lazily.
1026  OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
1027  builder.setInsertionPointToStart(block);
1028 
1029  // If the type of this argument changed and the argument is still live, we
1030  // need to materialize a conversion.
1031  if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
1032  continue;
1033  Operation *liveUser = findLiveUser(origArg);
1034  if (!liveUser)
1035  continue;
1036 
  // The argument counts as dropped when the lookup yields `origArg` itself
  // (presumably lookupOrDefault does this when nothing is mapped).
1037  Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
1038  bool isDroppedArg = replacementValue == origArg;
  // Insert after the replacement value so the materialization can use it.
1039  if (!isDroppedArg)
1040  builder.setInsertionPointAfterValue(replacementValue);
1041  Value newArg;
1042  if (converter) {
1043  newArg = converter->materializeSourceConversion(
1044  builder, origArg.getLoc(), origArg.getType(),
1045  isDroppedArg ? ValueRange() : ValueRange(replacementValue));
1046  assert((!newArg || newArg.getType() == origArg.getType()) &&
1047  "materialization hook did not provide a value of the expected "
1048  "type");
1049  }
1050  if (!newArg) {
  // NOTE(review): a line is missing here in this listing (numbering gap at
  // 1051); presumably it declares `diag`, initialized from the emitError
  // call below.
1052  emitError(origArg.getLoc())
1053  << "failed to materialize conversion for block argument #"
1054  << it.index() << " that remained live after conversion, type was "
1055  << origArg.getType();
1056  if (!isDroppedArg)
1057  diag << ", with target type " << replacementValue.getType();
1058  diag.attachNote(liveUser->getLoc())
1059  << "see existing live user here: " << *liveUser;
1060  return failure();
1061  }
1062  rewriterImpl.mapping.map(origArg, newArg);
1063  }
1064  return success();
1065 }
1066 
1067 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1068  Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
1069  if (!repl)
1070  return;
1071 
1072  if (isa<BlockArgument>(repl)) {
1073  rewriter.replaceAllUsesWith(arg, repl);
1074  return;
1075  }
1076 
1077  // If the replacement value is an operation, we check to make sure that we
1078  // don't replace uses that are within the parent operation of the
1079  // replacement value.
1080  Operation *replOp = cast<OpResult>(repl).getOwner();
1081  Block *replBlock = replOp->getBlock();
1082  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1083  Operation *user = operand.getOwner();
1084  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1085  });
1086 }
1087 
1088 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
1089 
1090 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
  // Commit the replacement of `op`:
  //  - notify the listener,
  //  - redirect uses of each result to its mapped replacement (if any),
  //  - unlink `op`; it is only destroyed later, in cleanup().
1091  auto *listener =
1092  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1093 
1094  // Compute replacement values.
  // A result without a mapped value of its own type gets a null replacement
  // and keeps its existing uses.
1095  SmallVector<Value> replacements =
1096  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1097  return rewriterImpl.mapping.lookupOrNull(result, result.getType());
1098  });
1099 
1100  // Notify the listener that the operation is about to be replaced.
1101  if (listener)
1102  listener->notifyOperationReplaced(op, replacements);
1103 
1104  // Replace all uses with the new values.
1105  for (auto [result, newValue] :
1106  llvm::zip_equal(op->getResults(), replacements))
1107  if (newValue)
1108  rewriter.replaceAllUsesWith(result, newValue);
1109 
1110  // The original op will be erased, so remove it from the set of unlegalized
1111  // ops.
1112  if (getConfig().unlegalizedOps)
1113  getConfig().unlegalizedOps->erase(op);
1114 
1115  // Notify the listener that the operation (and its nested operations) was
1116  // erased.
1117  if (listener) {
  // NOTE(review): a line is missing here in this listing (numbering gap at
  // 1118); presumably it starts a walk over `op` that invokes the lambda
  // below for `op` and every nested op.
1119  [&](Operation *op) { listener->notifyOperationErased(op); });
1120  }
1121 
1122  // Do not erase the operation yet. It may still be referenced in `mapping`.
1123  // Just unlink it for now and erase it during cleanup.
1124  op->getBlock()->getOperations().remove(op);
1125 }
1126 
1127 void ReplaceOperationRewrite::rollback() {
1128  for (auto result : op->getResults())
1129  rewriterImpl.mapping.erase(result);
1130 }
1131 
1132 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1133  rewriter.eraseOp(op);
1134 }
1135 
1136 void CreateOperationRewrite::rollback() {
1137  for (Region &region : op->getRegions()) {
1138  while (!region.getBlocks().empty())
1139  region.getBlocks().remove(region.getBlocks().begin());
1140  }
1141  op->dropAllUses();
1142  op->erase();
1143 }
1144 
1145 void UnresolvedMaterializationRewrite::rollback() {
1146  if (getMaterializationKind() == MaterializationKind::Target) {
1147  for (Value input : op->getOperands())
1148  rewriterImpl.mapping.erase(input);
1149  }
1150  op->erase();
1151 }
1152 
1153 void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
1154  rewriter.eraseOp(op);
1155 }
1156 
// Commit every recorded rewrite, then run a separate cleanup pass. Ops and
// blocks are therefore only destroyed after all rewrites have been committed.
1158  // Commit all rewrites.
  // Rewrites are committed in the order they were recorded; `config.listener`
  // (if any) is attached to the committing rewriter.
1159  IRRewriter rewriter(context, config.listener);
1160  for (auto &rewrite : rewrites)
1161  rewrite->commit(rewriter);
1162 
1163  // Clean up all rewrites.
  // SingleEraseRewriter never erases the same op/block twice and is its own
  // listener (not `config.listener`).
1164  SingleEraseRewriter eraseRewriter(context);
1165  for (auto &rewrite : rewrites)
1166  rewrite->cleanup(eraseRewriter);
1167 }
1168 
1169 //===----------------------------------------------------------------------===//
1170 // State Management
1171 
// Snapshot the current lengths of the rewrite log and of the ignored/replaced
// operation sets. Resetting to the snapshot truncates back to these lengths.
1173  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1174 }
1175 
// Restore a RewriterState snapshot `state`:
//  - roll back every rewrite recorded after it,
//  - trim the ignored/replaced operation sets back to their recorded sizes.
1177  // Undo any rewrites.
1178  undoRewrites(state.numRewrites);
1179 
1180  // Pop all of the recorded ignored operations that are no longer valid.
1181  while (ignoredOps.size() != state.numIgnoredOperations)
1182  ignoredOps.pop_back();
1183 
  // Likewise drop the replaced/erased ops recorded after the snapshot.
1184  while (replacedOps.size() != state.numReplacedOps)
1185  replacedOps.pop_back();
1186 }
1187 
1188 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1189  for (auto &rewrite :
1190  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1191  rewrite->rollback();
1192  rewrites.resize(numRewritesToKeep);
1193 }
1194 
// Remap `values` through `mapping`, converting each value to the single legal
// type reported by `currentTypeConverter` (if one is active). On success, one
// value per input is appended to `remapped`. If the converter cannot convert a
// value's type, a match failure is reported and failure is returned.
1196  StringRef valueDiagTag, std::optional<Location> inputLoc,
1197  PatternRewriter &rewriter, ValueRange values,
1198  SmallVectorImpl<Value> &remapped) {
1199  remapped.reserve(llvm::size(values));
1200 
1201  SmallVector<Type, 1> legalTypes;
1202  for (const auto &it : llvm::enumerate(values)) {
1203  Value operand = it.value();
1204  Type origType = operand.getType();
1205 
1206  // If a converter was provided, get the desired legal types for this
1207  // operand.
1208  Type desiredType;
1209  if (currentTypeConverter) {
1210  // If there is no legal conversion, fail to match this pattern.
1211  legalTypes.clear();
1212  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1213  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1214  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1215  diag << "unable to convert type for " << valueDiagTag << " #"
1216  << it.index() << ", type was " << origType;
1217  });
1218  return failure();
1219  }
1220  // TODO: There currently isn't any mechanism to do 1->N type conversion
1221  // via the PatternRewriter replacement API, so for now we just ignore it.
1222  if (legalTypes.size() == 1)
1223  desiredType = legalTypes.front();
1224  } else {
1225  // TODO: What we should do here is just set `desiredType` to `origType`
1226  // and then handle the necessary type conversions after the conversion
1227  // process has finished. Unfortunately a lot of patterns currently rely on
1228  // receiving the new operands even if the types change, so we keep the
1229  // original behavior here for now until all of the patterns relying on
1230  // this get updated.
1231  }
  // `desiredType` is null here when there is no converter or the conversion
  // is not 1->1.
1232  Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1233 
1234  // Handle the case where the conversion was 1->1 and the new operand type
1235  // isn't legal.
1236  Type newOperandType = newOperand.getType();
1237  if (currentTypeConverter && desiredType && newOperandType != desiredType) {
1238  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
  // NOTE(review): the line declaring `castValue` is missing from this listing
  // (numbering gap at 1239). Presumably it calls
  // buildUnresolvedTargetMaterialization with the arguments below.
1240  operandLoc, newOperand, desiredType, currentTypeConverter);
  // Map the cast so that later lookups starting from this value reach it
  // (hedged: depends on ConversionValueMapping's lookup semantics).
1241  mapping.map(mapping.lookupOrDefault(newOperand), castValue);
1242  newOperand = castValue;
1243  }
1244  remapped.push_back(newOperand);
1245  }
1246  return success();
1247 }
1248 
// Legalization skips an op both when it was explicitly ignored and when it
// was already replaced/erased.
1250  // Check to see if this operation is ignored or was replaced.
1251  return replacedOps.count(op) || ignoredOps.count(op);
1252 }
1253 
// Only `replacedOps` is consulted; merely ignored ops are not reported as
// replaced.
1255  // Check to see if this operation was replaced.
1256  return replacedOps.count(op);
1257 }
1258 
1259 //===----------------------------------------------------------------------===//
1260 // Type Conversion
1261 
// Steps:
//  - register `converter` for `region`,
//  - convert the signatures of all non-entry blocks,
//  - convert the entry block, using `entryConversion` if provided.
// Returns the (possibly new) entry block, nullptr for an empty region, or
// failure when a block signature cannot be computed.
1263  ConversionPatternRewriter &rewriter, Region *region,
1264  const TypeConverter &converter,
1265  TypeConverter::SignatureConversion *entryConversion) {
1266  regionToConverter[region] = &converter;
1267  if (region->empty())
1268  return nullptr;
1269 
1270  // Convert the arguments of each non-entry block within the region.
  // make_early_inc_range is needed because converting a block inserts a
  // replacement block right after it and unlinks the old one. Advancing the
  // iterator first keeps the loop valid and skips the replacement blocks.
1271  for (Block &block :
1272  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1273  // Compute the signature for the block with the provided converter.
1274  std::optional<TypeConverter::SignatureConversion> conversion =
1275  converter.convertBlockSignature(&block);
1276  if (!conversion)
1277  return failure();
1278  // Convert the block with the computed signature.
1279  applySignatureConversion(rewriter, &block, &converter, *conversion);
1280  }
1281 
1282  // Convert the entry block. If an entry signature conversion was provided,
1283  // use that one. Otherwise, compute the signature with the type converter.
1284  if (entryConversion)
1285  return applySignatureConversion(rewriter, &region->front(), &converter,
1286  *entryConversion);
1287  std::optional<TypeConverter::SignatureConversion> conversion =
1288  converter.convertBlockSignature(&region->front());
1289  if (!conversion)
1290  return failure();
1291  return applySignatureConversion(rewriter, &region->front(), &converter,
1292  *conversion);
1293 }
1294 
// Apply `signatureConversion` to `block`:
//  - create a new block right after `block` with the converted argument types,
//  - move all operations into it and redirect block uses to it,
//  - map each original argument to its replacement (directly, via a provided
//    replacement value, or via unresolved materializations),
//  - record the corresponding rewrites and erase `block`.
// Returns `block` unchanged if its argument types already match.
1296  ConversionPatternRewriter &rewriter, Block *block,
1297  const TypeConverter *converter,
1298  TypeConverter::SignatureConversion &signatureConversion) {
1299  OpBuilder::InsertionGuard g(rewriter);
1300 
1301  // If no arguments are being changed or added, there is nothing to do.
1302  unsigned origArgCount = block->getNumArguments();
1303  auto convertedTypes = signatureConversion.getConvertedTypes();
1304  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1305  return block;
1306 
1307  // Compute the locations of all block arguments in the new block.
  // New arguments default to an unknown location. Arguments that come from an
  // original argument that is not replaced by a value inherit its location.
1308  SmallVector<Location> newLocs(convertedTypes.size(),
1309  rewriter.getUnknownLoc());
1310  for (unsigned i = 0; i < origArgCount; ++i) {
1311  auto inputMap = signatureConversion.getInputMapping(i);
1312  if (!inputMap || inputMap->replacementValue)
1313  continue;
1314  Location origLoc = block->getArgument(i).getLoc();
1315  for (unsigned j = 0; j < inputMap->size; ++j)
1316  newLocs[inputMap->inputNo + j] = origLoc;
1317  }
1318 
1319  // Insert a new block with the converted block argument types and move all ops
1320  // from the old block to the new block.
1321  Block *newBlock =
1322  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1323  convertedTypes, newLocs);
1324 
1325  // If a listener is attached to the dialect conversion, ops cannot be moved
1326  // to the destination block in bulk ("fast path"). This is because at the time
1327  // the notifications are sent, it is unknown which ops were moved. Instead,
1328  // ops should be moved one-by-one ("slow path"), so that a separate
1329  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1330  // a bit more efficient, so we try to do that when possible.
1331  bool fastPath = !config.listener;
1332  if (fastPath) {
1333  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1334  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1335  } else {
1336  while (!block->empty())
1337  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1338  }
1339 
1340  // Replace all uses of the old block with the new block.
1341  block->replaceAllUsesWith(newBlock);
1342 
1343  // Remap each of the original arguments as determined by the signature
1344  // conversion.
  // NOTE(review): a line is missing here in this listing (numbering gap at
  // 1345). Presumably it declares `argInfo`, which holds one optional
  // ConvertedArgInfo per original argument.
1346  argInfo.resize(origArgCount);
1347 
1348  for (unsigned i = 0; i != origArgCount; ++i) {
1349  auto inputMap = signatureConversion.getInputMapping(i);
  // Arguments without an input mapping keep an empty `argInfo` entry.
1350  if (!inputMap)
1351  continue;
1352  BlockArgument origArg = block->getArgument(i);
1353 
1354  // If inputMap->replacementValue is not nullptr, then the argument is
1355  // dropped and a replacement value is provided to be the remappedValue.
1356  if (inputMap->replacementValue) {
1357  assert(inputMap->size == 0 &&
1358  "invalid to provide a replacement value when the argument isn't "
1359  "dropped");
1360  mapping.map(origArg, inputMap->replacementValue);
1361  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1362  continue;
1363  }
1364 
1365  // Otherwise, this is a 1->1+ mapping.
1366  auto replArgs =
1367  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1368  Value newArg;
1369 
1370  // If this is a 1->1 mapping and the types of new and replacement arguments
1371  // match (i.e. it's an identity map), then the argument is mapped to its
1372  // original type.
1373  // FIXME: We simply pass through the replacement argument if there wasn't a
1374  // converter, which isn't great as it allows implicit type conversions to
1375  // appear. We should properly restructure this code to handle cases where a
1376  // converter isn't provided and also to properly handle the case where an
1377  // argument materialization is actually a temporary source materialization
1378  // (e.g. in the case of 1->N).
1379  if (replArgs.size() == 1 &&
1380  (!converter || replArgs[0].getType() == origArg.getType())) {
1381  newArg = replArgs.front();
1382  mapping.map(origArg, newArg);
1383  } else {
1384  // Build argument materialization: new block arguments -> old block
1385  // argument type.
  // NOTE(review): the line declaring `argMat` is missing from this listing
  // (numbering gap at 1386). Presumably it calls
  // buildUnresolvedArgumentMaterialization with the arguments below.
1387  newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
1388  mapping.map(origArg, argMat);
1389 
1390  // Build target materialization: old block argument type -> legal type.
1391  // Note: This function returns an "empty" type if no valid conversion to
1392  // a legal type exists. In that case, we continue the conversion with the
1393  // original block argument type.
  // NOTE(review): this branch is also reached when `converter` is null and
  // the mapping is not 1->1 (replArgs.size() != 1). In that case the call
  // below dereferences a null pointer. Confirm that callers always pass a
  // converter for such mappings.
1394  Type legalOutputType = converter->convertType(origArg.getType());
1395  if (legalOutputType && legalOutputType != origArg.getType()) {
  // NOTE(review): the line assigning `newArg` is missing from this listing
  // (numbering gap at 1396). Presumably it calls
  // buildUnresolvedTargetMaterialization with the arguments below.
1397  origArg.getLoc(), argMat, legalOutputType, converter);
1398  mapping.map(argMat, newArg);
1399  } else {
1400  newArg = argMat;
1401  }
1402  }
1403 
1404  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1405  argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
1406  }
1407 
1408  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1409  converter);
1410 
1411  // Erase the old block. (It is just unlinked for now and will be erased during
1412  // cleanup.)
1413  rewriter.eraseBlock(block);
1414 
1415  return newBlock;
1416 }
1417 
1418 //===----------------------------------------------------------------------===//
1419 // Materializations
1420 //===----------------------------------------------------------------------===//
1421 
1422 /// Build an unresolved materialization operation given an output type and set
1423 /// of input operands.
/// If the single input already has `outputType`, it is returned directly.
/// Otherwise an UnrealizedConversionCastOp is created at `insertPt` in
/// `insertBlock`, and an UnresolvedMaterializationRewrite of kind `kind` is
/// recorded for it.
1425  MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1426  Location loc, ValueRange inputs, Type outputType,
1427  const TypeConverter *converter) {
1428  // Avoid materializing an unnecessary cast.
1429  if (inputs.size() == 1 && inputs.front().getType() == outputType)
1430  return inputs.front();
1431 
1432  // Create an unresolved materialization. We use a new OpBuilder to avoid
1433  // tracking the materialization like we do for other operations.
  // (This builder presumably has no listener, so notifyOperationInserted does
  // not record a CreateOperationRewrite for the cast.)
1434  OpBuilder builder(insertBlock, insertPt);
1435  auto convertOp =
1436  builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1437  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1438  return convertOp.getResult(0);
1439 }
1441  Block *block, Location loc, ValueRange inputs, Type outputType,
1442  const TypeConverter *converter) {
  // Materialize at the very beginning of `block`.
  // NOTE(review): the line that starts this return statement is missing from
  // this listing (numbering gap at 1443).
1444  block->begin(), loc, inputs, outputType,
1445  converter);
1446 }
1448  Location loc, Value input, Type outputType,
1449  const TypeConverter *converter) {
  // Insert right after the op that defines `input`. For a block argument,
  // insert at the start of its block.
1450  Block *insertBlock = input.getParentBlock();
1451  Block::iterator insertPt = insertBlock->begin();
1452  if (OpResult inputRes = dyn_cast<OpResult>(input))
1453  insertPt = ++inputRes.getOwner()->getIterator();
1454 
1455  return buildUnresolvedMaterialization(MaterializationKind::Target,
1456  insertBlock, insertPt, loc, input,
1457  outputType, converter);
1458 }
1459 
1460 //===----------------------------------------------------------------------===//
1461 // Rewriter Notification Hooks
1462 
// Record the insertion of `op`. An unset `previous` insertion point means the
// op was newly created. Otherwise the op was moved, and the move records its
// old block and the op it was positioned before (nullptr if it was at the end).
1464  Operation *op, OpBuilder::InsertPoint previous) {
1465  LLVM_DEBUG({
1466  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1467  << ")\n";
1468  });
1469  assert(!wasOpReplaced(op->getParentOp()) &&
1470  "attempting to insert into a block within a replaced/erased op");
1471 
1472  if (!previous.isSet()) {
1473  // This is a newly created op.
1474  appendRewrite<CreateOperationRewrite>(op);
1475  return;
1476  }
1477  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1478  ? nullptr
1479  : &*previous.getPoint();
1480  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1481 }
1482 
// Record that `op` is being replaced by `newValues`, one per result; a null
// value means the result is dropped. This:
//  - adds the non-null replacements to `mapping`,
//  - records a ReplaceOperationRewrite,
//  - marks `op` and all nested ops as replaced.
1484  ValueRange newValues) {
1485  assert(newValues.size() == op->getNumResults());
  // NOTE(review): this checks `ignoredOps`, but replaced ops are tracked in
  // `replacedOps`. Given the message, `replacedOps` may have been intended —
  // confirm.
1486  assert(!ignoredOps.contains(op) && "operation was already replaced");
1487 
1488  // Track if any of the results changed, e.g. erased and replaced with null.
1489  bool resultChanged = false;
1490 
1491  // Create mappings for each of the new result values.
1492  for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
1493  if (!newValue) {
1494  resultChanged = true;
1495  continue;
1496  }
1497  // Remap, and check for any result type changes.
1498  mapping.map(result, newValue);
1499  resultChanged |= (newValue.getType() != result.getType());
1500  }
1501 
1502  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
1503  resultChanged);
1504 
1505  // Mark this operation and all nested ops as replaced.
1506  op->walk([&](Operation *op) { replacedOps.insert(op); });
1507 }
1508 
// Record the erasure of `block`, together with its region and the block that
// currently follows it (nullptr if it is the last one).
1510  Region *region = block->getParent();
1511  Block *origNextBlock = block->getNextNode();
1512  appendRewrite<EraseBlockRewrite>(block, region, origNextBlock);
1513 }
1514 
1516  Block *block, Region *previous, Region::iterator previousIt) {
1517  assert(!wasOpReplaced(block->getParentOp()) &&
1518  "attempting to insert into a region within a replaced/erased op");
1519  LLVM_DEBUG(
1520  {
1521  Operation *parent = block->getParentOp();
1522  if (parent) {
1523  logger.startLine() << "** Insert Block into : '" << parent->getName()
1524  << "'(" << parent << ")\n";
1525  } else {
1526  logger.startLine()
1527  << "** Insert Block into detached Region (nullptr parent op)'";
1528  }
1529  });
1530 
1531  if (!previous) {
1532  // This is a newly created block.
1533  appendRewrite<CreateBlockRewrite>(block);
1534  return;
1535  }
1536  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1537  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1538 }
1539 
// Record an InlineBlockRewrite for inlining `srcBlock` into `block` before
// `before`.
1541  Block *block, Block *srcBlock, Block::iterator before) {
1542  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1543 }
1544 
// Only has an effect in debug builds: the failure reason is rendered and
// logged inside LLVM_DEBUG.
1546  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1547  LLVM_DEBUG({
  // NOTE(review): lines are missing from this listing here (numbering gaps at
  // 1548 and 1552). Presumably 1548 declares `diag` and 1552 invokes
  // `config.notifyCallback`.
1549  reasonCallback(diag);
1550  logger.startLine() << "** Failure : " << diag.str() << "\n";
1551  if (config.notifyCallback)
1553  });
1554 }
1555 
1556 //===----------------------------------------------------------------------===//
1557 // ConversionPatternRewriter
1558 //===----------------------------------------------------------------------===//
1559 
1560 ConversionPatternRewriter::ConversionPatternRewriter(
1561  MLIRContext *ctx, const ConversionConfig &config)
1562  : PatternRewriter(ctx),
1563  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1564  setListener(impl.get());
1565 }
1566 
1568 
// Replace `op` with the results of `newOp`.
1570  assert(op && newOp && "expected non-null op");
1571  replaceOp(op, newOp->getResults());
1572 }
1573 
// Record the replacement of `op` by `newValues`. Uses are only rewritten when
// the conversion commits (see ReplaceOperationRewrite::commit).
1575  assert(op->getNumResults() == newValues.size() &&
1576  "incorrect # of replacement values");
1577  LLVM_DEBUG({
1578  impl->logger.startLine()
1579  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1580  });
1581  impl->notifyOpReplaced(op, newValues);
1582 }
1583 
// Erasing an op is modeled as replacing all of its results with null values.
// The op is only marked here; it is unlinked and destroyed once the
// conversion commits.
1585  LLVM_DEBUG({
1586  impl->logger.startLine()
1587  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1588  });
1589  SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1590  impl->notifyOpReplaced(op, nullRepls);
1591 }
1592 
// Erase `block`:
//  - mark every op in it for erasure,
//  - record the block erasure so it can be undone,
//  - unlink the block from its region.
1594  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1595  "attempting to erase a block within a replaced/erased op");
1596 
1597  // Mark all ops for erasure.
  // eraseOp only records a replacement and does not unlink the op, so
  // iterating the block while calling it is safe.
1598  for (Operation &op : *block)
1599  eraseOp(&op);
1600 
1601  // Unlink the block from its parent region. The block is kept in the rewrite
1602  // object and will be actually destroyed when rewrites are applied. This
1603  // allows us to keep the operations in the block live and undo the removal by
1604  // re-inserting the block.
1605  impl->notifyBlockIsBeingErased(block);
1606  block->getParent()->getBlocks().remove(block);
1607 }
1608 
// Public entry point. It checks that `block` does not live inside a
// replaced/erased op and then forwards to the implementation.
1610  Block *block, TypeConverter::SignatureConversion &conversion,
1611  const TypeConverter *converter) {
1612  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1613  "attempting to apply a signature conversion to a block within a "
1614  "replaced/erased op");
1615  return impl->applySignatureConversion(*this, block, converter, conversion);
1616 }
1617 
// Public entry point. It checks that `region` does not belong to a
// replaced/erased op and then forwards to the implementation.
1619  Region *region, const TypeConverter &converter,
1620  TypeConverter::SignatureConversion *entryConversion) {
1621  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1622  "attempting to apply a signature conversion to a block within a "
1623  "replaced/erased op");
1624  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1625 }
1626 
1628  Value to) {
1629  LLVM_DEBUG({
1630  Operation *parentOp = from.getOwner()->getParentOp();
1631  impl->logger.startLine() << "** Replace Argument : '" << from
1632  << "'(in region of '" << parentOp->getName()
1633  << "'(" << from.getOwner()->getParentOp() << ")\n";
1634  });
1635  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
1636  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1637 }
1638 
// Remap the single value `key`. Returns nullptr if the remapping fails.
1640  SmallVector<Value> remappedValues;
1641  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1642  remappedValues)))
1643  return nullptr;
1644  return remappedValues.front();
1645 }
1646 
// Remap `keys` into `results`. An empty input trivially succeeds.
1647 LogicalResult
1649  SmallVectorImpl<Value> &results) {
1650  if (keys.empty())
1651  return success();
1652  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1653  results);
1654 }
1655 
// Inline `source` into `dest` before `before`:
//  - replace each argument of `source` with the matching entry of
//    `argValues`,
//  - move all of its operations,
//  - erase `source`.
1657  Block::iterator before,
1658  ValueRange argValues) {
1659 #ifndef NDEBUG
1660  assert(argValues.size() == source->getNumArguments() &&
1661  "incorrect # of argument replacement values");
1662  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1663  "attempting to inline a block from a replaced/erased op");
1664  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1665  "attempting to inline a block into a replaced/erased op");
1666  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1667  // The source block will be deleted, so it should not have any users (i.e.,
1668  // there should be no predecessors).
1669  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1670  "expected 'source' to have no predecessors");
1671 #endif // NDEBUG
1672 
1673  // If a listener is attached to the dialect conversion, ops cannot be moved
1674  // to the destination block in bulk ("fast path"). This is because at the time
1675  // the notifications are sent, it is unknown which ops were moved. Instead,
1676  // ops should be moved one-by-one ("slow path"), so that a separate
1677  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1678  // a bit more efficient, so we try to do that when possible.
1679  bool fastPath = !impl->config.listener;
1680 
  // On the fast path the bulk move is recorded as one InlineBlockRewrite.
1681  if (fastPath)
1682  impl->notifyBlockBeingInlined(dest, source, before);
1683 
1684  // Replace all uses of block arguments.
1685  for (auto it : llvm::zip(source->getArguments(), argValues))
1686  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1687 
1688  if (fastPath) {
1689  // Move all ops at once.
1690  dest->getOperations().splice(before, source->getOperations());
1691  } else {
1692  // Move op by op.
  // Each move goes through the rewriter and is therefore recorded as a
  // MoveOperationRewrite by notifyOperationInserted.
1693  while (!source->empty())
1694  moveOpBefore(&source->front(), dest, before);
1695  }
1696 
1697  // Erase the source block.
1698  eraseBlock(source);
1699 }
1700 
// Record the start of an in-place modification of `op` so it can be rolled
// back. Debug builds also remember that an update is pending.
1702  assert(!impl->wasOpReplaced(op) &&
1703  "attempting to modify a replaced/erased op");
1704 #ifndef NDEBUG
1705  impl->pendingRootUpdates.insert(op);
1706 #endif
1707  impl->appendRewrite<ModifyOperationRewrite>(op);
1708 }
1709 
// Ends an in-place modification of `op` that was started earlier. The
// signature (and line 1713) are not part of this listing; presumably this is
// ConversionPatternRewriter::finalizeOpModification — confirm.
1711  assert(!impl->wasOpReplaced(op) &&
1712  "attempting to modify a replaced/erased op");
1714  // There is nothing to do here, we only need to track the operation at the
1715  // start of the update.
// The erase happens inside the assert. That is safe because the statement is
// only compiled when NDEBUG is undefined, i.e. exactly when assert is active
// and pendingRootUpdates is being maintained.
1716 #ifndef NDEBUG
1717  assert(impl->pendingRootUpdates.erase(op) &&
1718  "operation did not have a pending in-place update");
1719 #endif
1720 }
1721 
// Cancels the most recent in-place modification of `op`. It rolls back the
// newest ModifyOperationRewrite recorded for `op` and removes that rewrite from
// the rewrite log. The signature is not part of this listing; presumably this
// is ConversionPatternRewriter::cancelOpModification — confirm.
1723 #ifndef NDEBUG
1724  assert(impl->pendingRootUpdates.erase(op) &&
1725  "operation did not have a pending in-place update");
1726 #endif
1727  // Erase the last update for this operation.
1728  auto it = llvm::find_if(
1729  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1730  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1731  return modifyRewrite && modifyRewrite->getOperation() == op;
1732  });
1733  assert(it != impl->rewrites.rend() && "no root update started on op");
1734  (*it)->rollback();
// Convert the reverse iterator into a forward index. std::prev(rend())
// designates rewrites[0], so its distance to `it` is the forward position.
1735  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1736  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1737 }
1738 
// Returns the rewriter's implementation object (the signature is not part of
// this listing; presumably ConversionPatternRewriter::getImpl — confirm).
1740  return *impl;
1741 }
1742 
1743 //===----------------------------------------------------------------------===//
1744 // ConversionPattern
1745 //===----------------------------------------------------------------------===//
1746 
// Generic conversion-pattern entry point. It remaps the root op's operands to
// their converted values and then forwards to the operand-aware
// matchAndRewrite overload. Fails without rewriting if the operands cannot be
// remapped. (The line with the qualified name and the `Operation *op`
// parameter is not part of this listing; presumably
// ConversionPattern::matchAndRewrite — confirm.)
1747 LogicalResult
1749  PatternRewriter &rewriter) const {
// The static_cast is unchecked. The legalizer in this file always drives
// patterns with a ConversionPatternRewriter.
1750  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1751  auto &rewriterImpl = dialectRewriter.getImpl();
1752 
1753  // Track the current conversion pattern type converter in the rewriter.
// The previous converter is restored when the guard goes out of scope.
1754  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1755  getTypeConverter());
1756 
1757  // Remap the operands of the operation.
1758  SmallVector<Value, 4> operands;
1759  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1760  op->getOperands(), operands))) {
1761  return failure();
1762  }
1763  return matchAndRewrite(op, operands, dialectRewriter);
1764 }
1765 
1766 //===----------------------------------------------------------------------===//
1767 // OperationLegalizer
1768 //===----------------------------------------------------------------------===//
1769 
1770 namespace {
1771 /// A set of rewrite patterns that can be used to legalize a given operation.
1772 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1773 
1774 /// This class defines a recursive operation legalizer.
/// `legalize` tries, in order, target legality, the ignored set, folding and
/// the registered patterns. Whatever a successful fold or pattern produces is
/// legalized recursively.
1775 class OperationLegalizer {
1776 public:
1777  using LegalizationAction = ConversionTarget::LegalizationAction;
1778 
1779  OperationLegalizer(const ConversionTarget &targetInfo,
1780  const FrozenRewritePatternSet &patterns,
1781  const ConversionConfig &config);
1782 
1783  /// Returns true if the given operation is known to be illegal on the target.
1784  bool isIllegal(Operation *op) const;
1785 
1786  /// Attempt to legalize the given operation. Returns success if the operation
1787  /// was legalized, failure otherwise.
1788  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1789 
1790  /// Returns the conversion target in use by the legalizer.
1791  const ConversionTarget &getTarget() { return target; }
1792 
1793 private:
1794  /// Attempt to legalize the given operation by folding it.
1795  LogicalResult legalizeWithFold(Operation *op,
1796  ConversionPatternRewriter &rewriter);
1797 
1798  /// Attempt to legalize the given operation by applying a pattern. Returns
1799  /// success if the operation was legalized, failure otherwise.
1800  LogicalResult legalizeWithPattern(Operation *op,
1801  ConversionPatternRewriter &rewriter);
1802 
1803  /// Return true if the given pattern may be applied to the given operation,
1804  /// false otherwise.
1805  bool canApplyPattern(Operation *op, const Pattern &pattern,
1806  ConversionPatternRewriter &rewriter);
1807 
1808  /// Legalize the resultant IR after successfully applying the given pattern.
1809  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1810  ConversionPatternRewriter &rewriter,
1811  RewriterState &curState);
1812 
1813  /// Legalizes the actions registered during the execution of a pattern.
  /// Each of the three helpers below inspects the rewrites recorded in
  /// [state.numRewrites, newState.numRewrites).
1814  LogicalResult
1815  legalizePatternBlockRewrites(Operation *op,
1816  ConversionPatternRewriter &rewriter,
1818  RewriterState &state, RewriterState &newState);
  /// Legalize every operation created by the pattern.
1819  LogicalResult legalizePatternCreatedOperations(
1821  RewriterState &state, RewriterState &newState);
  /// Re-legalize every operation the pattern modified in place.
1822  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1824  RewriterState &state,
1825  RewriterState &newState);
1826 
1827  //===--------------------------------------------------------------------===//
1828  // Cost Model
1829  //===--------------------------------------------------------------------===//
1830 
1831  /// Build an optimistic legalization graph given the provided patterns. This
1832  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1833  /// patterns for operations that are not directly legal, but may be
1834  /// transitively legal for the current target given the provided patterns.
1835  void buildLegalizationGraph(
1836  LegalizationPatterns &anyOpLegalizerPatterns,
1838 
1839  /// Compute the benefit of each node within the computed legalization graph.
1840  /// This orders the patterns within 'legalizerPatterns' based upon two
1841  /// criteria:
1842  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1843  /// represent the more direct mapping to the target.
1844  /// 2) When comparing patterns with the same legalization depth, prefer the
1845  /// pattern with the highest PatternBenefit. This allows for users to
1846  /// prefer specific legalizations over others.
1847  void computeLegalizationGraphBenefit(
1848  LegalizationPatterns &anyOpLegalizerPatterns,
1850 
1851  /// Compute the legalization depth when legalizing an operation of the given
1852  /// type.
1853  unsigned computeOpLegalizationDepth(
1854  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1856 
1857  /// Apply the conversion cost model to the given set of patterns, and return
1858  /// the smallest legalization depth of any of the patterns. See
1859  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1860  unsigned applyCostModelToPatterns(
1861  LegalizationPatterns &patterns,
1862  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1864 
1865  /// The current set of patterns that have been applied.
  /// `canApplyPattern` inserts a pattern here unless it declares bounded
  /// rewrite recursion. The pattern is erased again when its application
  /// attempt ends. A pattern already present is not applied a second time.
1866  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1867 
1868  /// The legalization information provided by the target.
  /// Held by reference: the target must outlive this legalizer.
1869  const ConversionTarget &target;
1870 
1871  /// The pattern applicator to use for conversions.
1872  PatternApplicator applicator;
1873 
1874  /// Dialect conversion configuration.
  /// Held by reference: the configuration must outlive this legalizer.
1875  const ConversionConfig &config;
1876 };
1877 } // namespace
1878 
// Builds a legalizer for `targetInfo` and `patterns`. It computes the
// legalization graph and then orders the applicator's patterns with the cost
// model. `targetInfo` and `config` are stored by reference and must outlive
// the legalizer.
1879 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1880  const FrozenRewritePatternSet &patterns,
1881  const ConversionConfig &config)
1882  : target(targetInfo), applicator(patterns), config(config) {
1883  // The set of patterns that can be applied to illegal operations to transform
1884  // them into legal ones.
1886  LegalizationPatterns anyOpLegalizerPatterns;
1887 
// First decide which patterns can lead to legal IR, then rank them.
1888  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1889  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1890 }
1891 
1892 bool OperationLegalizer::isIllegal(Operation *op) const {
1893  return target.isIllegal(op);
1894 }
1895 
1896 LogicalResult
1897 OperationLegalizer::legalize(Operation *op,
1898  ConversionPatternRewriter &rewriter) {
1899 #ifndef NDEBUG
1900  const char *logLineComment =
1901  "//===-------------------------------------------===//\n";
1902 
1903  auto &logger = rewriter.getImpl().logger;
1904 #endif
1905  LLVM_DEBUG({
1906  logger.getOStream() << "\n";
1907  logger.startLine() << logLineComment;
1908  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1909  << op << ") {\n";
1910  logger.indent();
1911 
1912  // If the operation has no regions, just print it here.
1913  if (op->getNumRegions() == 0) {
1914  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1915  logger.getOStream() << "\n\n";
1916  }
1917  });
1918 
1919  // Check if this operation is legal on the target.
1920  if (auto legalityInfo = target.isLegal(op)) {
1921  LLVM_DEBUG({
1922  logSuccess(
1923  logger, "operation marked legal by the target{0}",
1924  legalityInfo->isRecursivelyLegal
1925  ? "; NOTE: operation is recursively legal; skipping internals"
1926  : "");
1927  logger.startLine() << logLineComment;
1928  });
1929 
1930  // If this operation is recursively legal, mark its children as ignored so
1931  // that we don't consider them for legalization.
1932  if (legalityInfo->isRecursivelyLegal) {
1933  op->walk([&](Operation *nested) {
1934  if (op != nested)
1935  rewriter.getImpl().ignoredOps.insert(nested);
1936  });
1937  }
1938 
1939  return success();
1940  }
1941 
1942  // Check to see if the operation is ignored and doesn't need to be converted.
1943  if (rewriter.getImpl().isOpIgnored(op)) {
1944  LLVM_DEBUG({
1945  logSuccess(logger, "operation marked 'ignored' during conversion");
1946  logger.startLine() << logLineComment;
1947  });
1948  return success();
1949  }
1950 
1951  // If the operation isn't legal, try to fold it in-place.
1952  // TODO: Should we always try to do this, even if the op is
1953  // already legal?
1954  if (succeeded(legalizeWithFold(op, rewriter))) {
1955  LLVM_DEBUG({
1956  logSuccess(logger, "operation was folded");
1957  logger.startLine() << logLineComment;
1958  });
1959  return success();
1960  }
1961 
1962  // Otherwise, we need to apply a legalization pattern to this operation.
1963  if (succeeded(legalizeWithPattern(op, rewriter))) {
1964  LLVM_DEBUG({
1965  logSuccess(logger, "");
1966  logger.startLine() << logLineComment;
1967  });
1968  return success();
1969  }
1970 
1971  LLVM_DEBUG({
1972  logFailure(logger, "no matched legalization pattern");
1973  logger.startLine() << logLineComment;
1974  });
1975  return failure();
1976 }
1977 
1978 LogicalResult
1979 OperationLegalizer::legalizeWithFold(Operation *op,
1980  ConversionPatternRewriter &rewriter) {
1981  auto &rewriterImpl = rewriter.getImpl();
1982  RewriterState curState = rewriterImpl.getCurrentState();
1983 
1984  LLVM_DEBUG({
1985  rewriterImpl.logger.startLine() << "* Fold {\n";
1986  rewriterImpl.logger.indent();
1987  });
1988 
1989  // Try to fold the operation.
1990  SmallVector<Value, 2> replacementValues;
1991  rewriter.setInsertionPoint(op);
1992  if (failed(rewriter.tryFold(op, replacementValues))) {
1993  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1994  return failure();
1995  }
1996  // An empty list of replacement values indicates that the fold was in-place.
1997  // As the operation changed, a new legalization needs to be attempted.
1998  if (replacementValues.empty())
1999  return legalize(op, rewriter);
2000 
2001  // Insert a replacement for 'op' with the folded replacement values.
2002  rewriter.replaceOp(op, replacementValues);
2003 
2004  // Recursively legalize any new constant operations.
2005  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2006  i != e; ++i) {
2007  auto *createOp =
2008  dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
2009  if (!createOp)
2010  continue;
2011  if (failed(legalize(createOp->getOperation(), rewriter))) {
2012  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2013  "failed to legalize generated constant '{0}'",
2014  createOp->getOperation()->getName()));
2015  rewriterImpl.resetState(curState);
2016  return failure();
2017  }
2018  }
2019 
2020  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2021  return success();
2022 }
2023 
// Attempt to legalize `op` with one of the applicator's patterns.
// For each candidate pattern:
//   - canApply gates whether the pattern may be tried and notifies the
//     listener when it is;
//   - onFailure resets the rewriter to the state captured before any pattern
//     ran;
//   - onSuccess legalizes whatever the pattern produced and resets the state if
//     that fails.
// Both callbacks take the pattern out of `appliedPatterns` again.
2024 LogicalResult
2025 OperationLegalizer::legalizeWithPattern(Operation *op,
2026  ConversionPatternRewriter &rewriter) {
2027  auto &rewriterImpl = rewriter.getImpl();
2028 
2029  // Functor that returns if the given pattern may be applied.
2030  auto canApply = [&](const Pattern &pattern) {
2031  bool canApply = canApplyPattern(op, pattern, rewriter);
2032  if (canApply && config.listener)
2033  config.listener->notifyPatternBegin(pattern, op);
2034  return canApply;
2035  };
2036 
2037  // Functor that cleans up the rewriter state after a pattern failed to match.
2038  RewriterState curState = rewriterImpl.getCurrentState();
2039  auto onFailure = [&](const Pattern &pattern) {
2040  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
// NOTE(review): notifyCallback is only invoked from inside LLVM_DEBUG, so it
// does not fire in builds or runs where this debug output is disabled.
2041  LLVM_DEBUG({
2042  logFailure(rewriterImpl.logger, "pattern failed to match");
2043  if (rewriterImpl.config.notifyCallback) {
2045  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2046  << "\" on op:\n"
2047  << *op;
2048  rewriterImpl.config.notifyCallback(diag);
2049  }
2050  });
2051  if (config.listener)
2052  config.listener->notifyPatternEnd(pattern, failure());
2053  rewriterImpl.resetState(curState);
2054  appliedPatterns.erase(&pattern);
2055  };
2056 
2057  // Functor that performs additional legalization when a pattern is
2058  // successfully applied.
2059  auto onSuccess = [&](const Pattern &pattern) {
2060  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
// The pattern stays in `appliedPatterns` while its results are legalized, so
// that nested legalization cannot re-apply it (see canApplyPattern).
2061  auto result = legalizePatternResult(op, pattern, rewriter, curState);
2062  appliedPatterns.erase(&pattern);
2063  if (failed(result))
2064  rewriterImpl.resetState(curState);
2065  if (config.listener)
2066  config.listener->notifyPatternEnd(pattern, result);
2067  return result;
2068  };
2069 
2070  // Try to match and rewrite a pattern on this operation.
2071  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2072  onSuccess);
2073 }
2074 
2075 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2076  ConversionPatternRewriter &rewriter) {
2077  LLVM_DEBUG({
2078  auto &os = rewriter.getImpl().logger;
2079  os.getOStream() << "\n";
2080  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2081  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2082  os.getOStream() << ")' {\n";
2083  os.indent();
2084  });
2085 
2086  // Ensure that we don't cycle by not allowing the same pattern to be
2087  // applied twice in the same recursion stack if it is not known to be safe.
2088  if (!pattern.hasBoundedRewriteRecursion() &&
2089  !appliedPatterns.insert(&pattern).second) {
2090  LLVM_DEBUG(
2091  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2092  return false;
2093  }
2094  return true;
2095 }
2096 
2097 LogicalResult
2098 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2099  ConversionPatternRewriter &rewriter,
2100  RewriterState &curState) {
2101  auto &impl = rewriter.getImpl();
2102 
2103 #ifndef NDEBUG
2104  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2105  // Check that the root was either replaced or updated in place.
2106  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2107  auto replacedRoot = [&] {
2108  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2109  };
2110  auto updatedRootInPlace = [&] {
2111  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2112  };
2113  assert((replacedRoot() || updatedRootInPlace()) &&
2114  "expected pattern to replace the root operation");
2115 #endif // NDEBUG
2116 
2117  // Legalize each of the actions registered during application.
2118  RewriterState newState = impl.getCurrentState();
2119  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2120  newState)) ||
2121  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2122  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2123  newState))) {
2124  return failure();
2125  }
2126 
2127  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2128  return success();
2129 }
2130 
2131 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2132  Operation *op, ConversionPatternRewriter &rewriter,
2133  ConversionPatternRewriterImpl &impl, RewriterState &state,
2134  RewriterState &newState) {
2135  SmallPtrSet<Operation *, 16> operationsToIgnore;
2136 
2137  // If the pattern moved or created any blocks, make sure the types of block
2138  // arguments get legalized.
2139  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2140  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2141  if (!rewrite)
2142  continue;
2143  Block *block = rewrite->getBlock();
2144  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2145  ReplaceBlockArgRewrite>(rewrite))
2146  continue;
2147  // Only check blocks outside of the current operation.
2148  Operation *parentOp = block->getParentOp();
2149  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2150  continue;
2151 
2152  // If the region of the block has a type converter, try to convert the block
2153  // directly.
2154  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2155  std::optional<TypeConverter::SignatureConversion> conversion =
2156  converter->convertBlockSignature(block);
2157  if (!conversion) {
2158  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2159  "block"));
2160  return failure();
2161  }
2162  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2163  continue;
2164  }
2165 
2166  // Otherwise, check that this operation isn't one generated by this pattern.
2167  // This is because we will attempt to legalize the parent operation, and
2168  // blocks in regions created by this pattern will already be legalized later
2169  // on. If we haven't built the set yet, build it now.
2170  if (operationsToIgnore.empty()) {
2171  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2172  ++i) {
2173  auto *createOp =
2174  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2175  if (!createOp)
2176  continue;
2177  operationsToIgnore.insert(createOp->getOperation());
2178  }
2179  }
2180 
2181  // If this operation should be considered for re-legalization, try it.
2182  if (operationsToIgnore.insert(parentOp).second &&
2183  failed(legalize(parentOp, rewriter))) {
2184  LLVM_DEBUG(logFailure(impl.logger,
2185  "operation '{0}'({1}) became illegal after rewrite",
2186  parentOp->getName(), parentOp));
2187  return failure();
2188  }
2189  }
2190  return success();
2191 }
2192 
// Legalize every operation created by a pattern, i.e. the op of each
// CreateOperationRewrite in [state.numRewrites, newState.numRewrites).
// Stops at, and reports, the first operation that fails to legalize.
2193 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2195  RewriterState &state, RewriterState &newState) {
2196  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2197  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2198  if (!createOp)
2199  continue;
2200  Operation *op = createOp->getOperation();
2201  if (failed(legalize(op, rewriter))) {
2202  LLVM_DEBUG(logFailure(impl.logger,
2203  "failed to legalize generated operation '{0}'({1})",
2204  op->getName(), op));
2205  return failure();
2206  }
2207  }
2208  return success();
2209 }
2210 
// Re-legalize every operation a pattern modified in place, i.e. the op of each
// ModifyOperationRewrite in [state.numRewrites, newState.numRewrites).
// Stops at the first operation that fails to legalize.
2211 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2213  RewriterState &state, RewriterState &newState) {
2214  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2215  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2216  if (!rewrite)
2217  continue;
2218  Operation *op = rewrite->getOperation();
2219  if (failed(legalize(op, rewriter))) {
2220  LLVM_DEBUG(logFailure(
2221  impl.logger, "failed to legalize operation updated in-place '{0}'",
2222  op->getName()));
2223  return failure();
2224  }
2225  }
2226  return success();
2227 }
2228 
2229 //===----------------------------------------------------------------------===//
2230 // Cost Model
2231 
// Populate `anyOpLegalizerPatterns` (patterns without a root op) and
// `legalizerPatterns` (root op -> patterns that may legalize it).
// - Patterns rooted at ops the target always treats as Legal are dropped.
// - If any root-less pattern exists, every remaining rooted pattern is
//   accepted without further analysis.
// - Otherwise a worklist fixpoint accepts a rooted pattern once each op it
//   generates either already has accepted patterns or has a target action
//   other than Illegal. Accepting a pattern re-queues the pending patterns of
//   the ops that may generate its root.
2232 void OperationLegalizer::buildLegalizationGraph(
2233  LegalizationPatterns &anyOpLegalizerPatterns,
2234  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2235  // A mapping between an operation and a set of operations that can be used to
2236  // generate it.
2238  // A mapping between an operation and any currently invalid patterns it has.
2240  // A worklist of patterns to consider for legality.
2241  SetVector<const Pattern *> patternWorklist;
2242 
2243  // Build the mapping from operations to the parent ops that may generate them.
2244  applicator.walkAllPatterns([&](const Pattern &pattern) {
2245  std::optional<OperationName> root = pattern.getRootKind();
2246 
2247  // If the pattern has no specific root, we can't analyze the relationship
2248  // between the root op and generated operations. Given that, add all such
2249  // patterns to the legalization set.
2250  if (!root) {
2251  anyOpLegalizerPatterns.push_back(&pattern);
2252  return;
2253  }
2254 
2255  // Skip operations that are always known to be legal.
2256  if (target.getOpAction(*root) == LegalizationAction::Legal)
2257  return;
2258 
2259  // Add this pattern to the invalid set for the root op and record this root
2260  // as a parent for any generated operations.
2261  invalidPatterns[*root].insert(&pattern);
2262  for (auto op : pattern.getGeneratedOps())
2263  parentOps[op].insert(*root);
2264 
2265  // Add this pattern to the worklist.
2266  patternWorklist.insert(&pattern);
2267  });
2268 
2269  // If there are any patterns that don't have a specific root kind, we can't
2270  // make direct assumptions about what operations will never be legalized.
2271  // Note: Technically we could, but it would require an analysis that may
2272  // recurse into itself. It would be better to perform this kind of filtering
2273  // at a higher level than here anyways.
2274  if (!anyOpLegalizerPatterns.empty()) {
2275  for (const Pattern *pattern : patternWorklist)
2276  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2277  return;
2278  }
2279 
2280  while (!patternWorklist.empty()) {
2281  auto *pattern = patternWorklist.pop_back_val();
2282 
2283  // Check to see if any of the generated operations are invalid.
// A generated op is invalid if it has no accepted patterns yet and the target
// either has no action for it or marks it Illegal.
2284  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2285  std::optional<LegalizationAction> action = target.getOpAction(op);
2286  return !legalizerPatterns.count(op) &&
2287  (!action || action == LegalizationAction::Illegal);
2288  }))
2289  continue;
2290 
2291  // Otherwise, all of the generated operations are valid, so this pattern can
2292  // legalize its root op: accept it and remove it from the invalid set.
2293  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2294  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2295 
2296  // Add any invalid patterns of the parent operations to see if they have now
2297  // become legal.
2298  for (auto op : parentOps[*pattern->getRootKind()])
2299  patternWorklist.set_union(invalidPatterns[op]);
2300  }
2301 }
2302 
// Rank all candidate patterns and hand the ranking to the applicator.
// 1. Sort each op's list in `legalizerPatterns` (and the root-less list) by
//    legalization depth, then by benefit.
// 2. Give each pattern a benefit equal to its distance from the end of its
//    sorted list, so earlier patterns rank higher.
// Patterns absent from their list are handled by the not-found branch below.
2303 void OperationLegalizer::computeLegalizationGraphBenefit(
2304  LegalizationPatterns &anyOpLegalizerPatterns,
2305  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2306  // The smallest pattern depth, when legalizing an operation.
2307  DenseMap<OperationName, unsigned> minOpPatternDepth;
2308 
2309  // For each operation that is transitively legal, compute a cost for it.
2310  for (auto &opIt : legalizerPatterns)
2311  if (!minOpPatternDepth.count(opIt.first))
2312  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2313  legalizerPatterns);
2314 
2315  // Apply the cost model to the patterns that can match any operation. Those
2316  // with a specific operation type are already resolved when computing the op
2317  // legalization depth.
2318  if (!anyOpLegalizerPatterns.empty())
2319  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2320  legalizerPatterns);
2321 
2322  // Apply a cost model to the pattern applicator. We order patterns first by
2323  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2324  // decreasing benefit.
2325  applicator.applyCostModel([&](const Pattern &pattern) {
2326  ArrayRef<const Pattern *> orderedPatternList;
// NOTE(review): operator[] inserts an empty list for roots without accepted
// patterns; such patterns then take the not-found branch below.
2327  if (std::optional<OperationName> rootName = pattern.getRootKind())
2328  orderedPatternList = legalizerPatterns[*rootName];
2329  else
2330  orderedPatternList = anyOpLegalizerPatterns;
2331 
2332  // If the pattern is not found, then it was removed and cannot be matched.
2333  auto *it = llvm::find(orderedPatternList, &pattern);
2334  if (it == orderedPatternList.end())
2336 
2337  // Patterns found earlier in the list have higher benefit.
2338  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2339  });
2340 }
2341 
2342 unsigned OperationLegalizer::computeOpLegalizationDepth(
2343  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2344  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2345  // Check for existing depth.
2346  auto depthIt = minOpPatternDepth.find(op);
2347  if (depthIt != minOpPatternDepth.end())
2348  return depthIt->second;
2349 
2350  // If a mapping for this operation does not exist, then this operation
2351  // is always legal. Return 0 as the depth for a directly legal operation.
2352  auto opPatternsIt = legalizerPatterns.find(op);
2353  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2354  return 0u;
2355 
2356  // Record this initial depth in case we encounter this op again when
2357  // recursively computing the depth.
2358  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2359 
2360  // Apply the cost model to the operation patterns, and update the minimum
2361  // depth.
2362  unsigned minDepth = applyCostModelToPatterns(
2363  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2364  minOpPatternDepth[op] = minDepth;
2365  return minDepth;
2366 }
2367 
// Sort `patterns` in place by ascending legalization depth. Ties are broken by
// descending PatternBenefit (stable_sort keeps the original order otherwise).
// Returns the smallest depth among the patterns. A pattern's depth is 1 plus
// the deepest legalization depth of the ops it generates (minimum 1).
2368 unsigned OperationLegalizer::applyCostModelToPatterns(
2369  LegalizationPatterns &patterns,
2370  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2371  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2372  unsigned minDepth = std::numeric_limits<unsigned>::max();
2373 
2374  // Compute the depth for each pattern within the set.
2376  patternsByDepth.reserve(patterns.size());
2377  for (const Pattern *pattern : patterns) {
2378  unsigned depth = 1;
2379  for (auto generatedOp : pattern->getGeneratedOps()) {
2380  unsigned generatedOpDepth = computeOpLegalizationDepth(
2381  generatedOp, minOpPatternDepth, legalizerPatterns);
// NOTE(review): if generatedOpDepth is the in-progress sentinel (UINT_MAX)
// seeded by computeOpLegalizationDepth for a cyclic dependency,
// `generatedOpDepth + 1` wraps to 0 and the cycle does not raise the depth.
// Confirm that this is intended.
2382  depth = std::max(depth, generatedOpDepth + 1);
2383  }
2384  patternsByDepth.emplace_back(pattern, depth);
2385 
2386  // Update the minimum depth of the pattern list.
2387  minDepth = std::min(minDepth, depth);
2388  }
2389 
2390  // If the operation only has one legalization pattern, there is no need to
2391  // sort them.
2392  if (patternsByDepth.size() == 1)
2393  return minDepth;
2394 
2395  // Sort the patterns by those likely to be the most beneficial.
2396  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2397  [](const std::pair<const Pattern *, unsigned> &lhs,
2398  const std::pair<const Pattern *, unsigned> &rhs) {
2399  // First sort by the smaller pattern legalization
2400  // depth.
2401  if (lhs.second != rhs.second)
2402  return lhs.second < rhs.second;
2403 
2404  // Then sort by the larger pattern benefit.
2405  auto lhsBenefit = lhs.first->getBenefit();
2406  auto rhsBenefit = rhs.first->getBenefit();
2407  return lhsBenefit > rhsBenefit;
2408  });
2409 
2410  // Update the legalization pattern to use the new sorted list.
2411  patterns.clear();
2412  for (auto &patternIt : patternsByDepth)
2413  patterns.push_back(patternIt.first);
2414  return minDepth;
2415 }
2416 
2417 //===----------------------------------------------------------------------===//
2418 // OperationConverter
2419 //===----------------------------------------------------------------------===//
namespace {
/// Policy an OperationConverter applies to operations that fail to legalize.
enum OpConversionMode {
  /// Failed legalizations are tolerated, so illegal operations may remain in
  /// the IR alongside converted ones (unless explicitly marked illegal).
  Partial,

  /// Every operation must end up legal for the target; otherwise the whole
  /// conversion fails.
  Full,

  /// Legality is only analyzed: rewrites are computed to check legality but
  /// are not kept when the conversion succeeds.
  Analysis
};
} // namespace
2435 
2436 namespace mlir {
2437 // This class converts operations to a given conversion target via a set of
2438 // rewrite patterns. The conversion behaves differently depending on the
2439 // conversion mode.
  // The legalizer is handed `this->config`, i.e. this object's own copy of the
  // configuration. This is valid because `config` is declared (and therefore
  // initialized) before `opLegalizer`.
2441  explicit OperationConverter(const ConversionTarget &target,
2442  const FrozenRewritePatternSet &patterns,
2443  const ConversionConfig &config,
2444  OpConversionMode mode)
2445  : config(config), opLegalizer(target, patterns, this->config),
2446  mode(mode) {}
2447 
2448  /// Converts the given operations to the conversion target.
2449  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2450 
2451 private:
2452  /// Converts an operation with the given rewriter.
2453  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2454 
2455  /// This method is called after the conversion process to legalize any
2456  /// remaining artifacts and complete the conversion.
2457  LogicalResult finalize(ConversionPatternRewriter &rewriter);
2458 
2459  /// Legalize the types of converted block arguments.
2460  LogicalResult
2461  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2462  ConversionPatternRewriterImpl &rewriterImpl);
2463 
2464  /// Legalize any unresolved type materializations.
2465  LogicalResult legalizeUnresolvedMaterializations(
2466  ConversionPatternRewriter &rewriter,
2467  ConversionPatternRewriterImpl &rewriterImpl,
2468  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
2469 
2470  /// Legalize an operation result that was marked as "erased".
2471  LogicalResult
2472  legalizeErasedResult(Operation *op, OpResult result,
2473  ConversionPatternRewriterImpl &rewriterImpl);
2474 
2475  /// Legalize an operation result that was replaced with a value of a different
2476  /// type.
2477  LogicalResult legalizeChangedResultType(
2478  Operation *op, OpResult result, Value newValue,
2479  const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2480  ConversionPatternRewriterImpl &rewriterImpl,
2481  const DenseMap<Value, SmallVector<Value>> &inverseMapping);
2482 
2483  /// Dialect conversion configuration.
  /// Must stay declared before `opLegalizer`, which keeps a reference to it.
2484  ConversionConfig config;
2485 
2486  /// The legalizer to use when converting operations.
2487  OperationLegalizer opLegalizer;
2488 
2489  /// The conversion mode to use when legalizing operations.
2490  OpConversionMode mode;
2491 };
2492 } // namespace mlir
2493 
2494 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2495  Operation *op) {
2496  // Legalize the given operation.
2497  if (failed(opLegalizer.legalize(op, rewriter))) {
2498  // Handle the case of a failed conversion for each of the different modes.
2499  // Full conversions expect all operations to be converted.
2500  if (mode == OpConversionMode::Full)
2501  return op->emitError()
2502  << "failed to legalize operation '" << op->getName() << "'";
2503  // Partial conversions allow conversions to fail iff the operation was not
2504  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2505  // set, non-legalizable ops are added to that set.
2506  if (mode == OpConversionMode::Partial) {
2507  if (opLegalizer.isIllegal(op))
2508  return op->emitError()
2509  << "failed to legalize operation '" << op->getName()
2510  << "' that was explicitly marked illegal";
2511  if (config.unlegalizedOps)
2512  config.unlegalizedOps->insert(op);
2513  }
2514  } else if (mode == OpConversionMode::Analysis) {
2515  // Analysis conversions don't fail if any operations fail to legalize,
2516  // they are only interested in the operations that were successfully
2517  // legalized.
2518  if (config.legalizableOps)
2519  config.legalizableOps->insert(op);
2520  }
2521  return success();
2522 }
2523 
2525  if (ops.empty())
2526  return success();
2527  const ConversionTarget &target = opLegalizer.getTarget();
2528 
2529  // Compute the set of operations and blocks to convert.
2530  SmallVector<Operation *> toConvert;
2531  for (auto *op : ops) {
2533  [&](Operation *op) {
2534  toConvert.push_back(op);
2535  // Don't check this operation's children for conversion if the
2536  // operation is recursively legal.
2537  auto legalityInfo = target.isLegal(op);
2538  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2539  return WalkResult::skip();
2540  return WalkResult::advance();
2541  });
2542  }
2543 
2544  // Convert each operation and discard rewrites on failure.
2545  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2546  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2547 
2548  for (auto *op : toConvert)
2549  if (failed(convert(rewriter, op)))
2550  return rewriterImpl.undoRewrites(), failure();
2551 
2552  // Now that all of the operations have been converted, finalize the conversion
2553  // process to ensure any lingering conversion artifacts are cleaned up and
2554  // legalized.
2555  if (failed(finalize(rewriter)))
2556  return rewriterImpl.undoRewrites(), failure();
2557 
2558  // After a successful conversion, apply rewrites if this is not an analysis
2559  // conversion.
2560  if (mode == OpConversionMode::Analysis) {
2561  rewriterImpl.undoRewrites();
2562  } else {
2563  rewriterImpl.applyRewrites();
2564  }
2565  return success();
2566 }
2567 
2568 LogicalResult
2569 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2570  std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
2571  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2572  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2573  inverseMapping)) ||
2574  failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2575  return failure();
2576 
2577  // Process requested operation replacements.
2578  for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
2579  auto *opReplacement =
2580  dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
2581  if (!opReplacement || !opReplacement->hasChangedResults())
2582  continue;
2583  Operation *op = opReplacement->getOperation();
2584  for (OpResult result : op->getResults()) {
2585  Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2586 
2587  // If the operation result was replaced with null, all of the uses of this
2588  // value should be replaced.
2589  if (!newValue) {
2590  if (failed(legalizeErasedResult(op, result, rewriterImpl)))
2591  return failure();
2592  continue;
2593  }
2594 
2595  // Otherwise, check to see if the type of the result changed.
2596  if (result.getType() == newValue.getType())
2597  continue;
2598 
2599  // Compute the inverse mapping only if it is really needed.
2600  if (!inverseMapping)
2601  inverseMapping = rewriterImpl.mapping.getInverse();
2602 
2603  // Legalize this result.
2604  rewriter.setInsertionPoint(op);
2605  if (failed(legalizeChangedResultType(
2606  op, result, newValue, opReplacement->getConverter(), rewriter,
2607  rewriterImpl, *inverseMapping)))
2608  return failure();
2609  }
2610  }
2611  return success();
2612 }
2613 
2614 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2615  ConversionPatternRewriter &rewriter,
2616  ConversionPatternRewriterImpl &rewriterImpl) {
2617  // Functor used to check if all users of a value will be dead after
2618  // conversion.
2619  auto findLiveUser = [&](Value val) {
2620  auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2621  return rewriterImpl.isOpIgnored(user);
2622  });
2623  return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2624  };
2625  // Note: `rewrites` may be reallocated as the loop is running.
2626  for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
2627  ++i) {
2628  auto &rewrite = rewriterImpl.rewrites[i];
2629  if (auto *blockTypeConversionRewrite =
2630  dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
2631  if (failed(blockTypeConversionRewrite->materializeLiveConversions(
2632  findLiveUser)))
2633  return failure();
2634  }
2635  return success();
2636 }
2637 
2638 /// Replace the results of a materialization operation with the given values.
2639 static void
2641  ResultRange matResults, ValueRange values,
2642  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2643  matResults.replaceAllUsesWith(values);
2644 
2645  // For each of the materialization results, update the inverse mappings to
2646  // point to the replacement values.
2647  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
2648  auto inverseMapIt = inverseMapping.find(matResult);
2649  if (inverseMapIt == inverseMapping.end())
2650  continue;
2651 
2652  // Update the reverse mapping, or remove the mapping if we couldn't update
2653  // it. Not being able to update signals that the mapping would have become
2654  // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
2655  // propagated through temporary materializations. We simply drop the
2656  // mapping, and let the post-conversion replacement logic handle updating
2657  // uses.
2658  for (Value inverseMapVal : inverseMapIt->second)
2659  if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
2660  rewriterImpl.mapping.erase(inverseMapVal);
2661  }
2662 }
2663 
2664 /// Compute all of the unresolved materializations that will persist beyond the
2665 /// conversion process, and require inserting a proper user materialization for.
2668  &materializationOps,
2669  ConversionPatternRewriter &rewriter,
2670  ConversionPatternRewriterImpl &rewriterImpl,
2671  DenseMap<Value, SmallVector<Value>> &inverseMapping,
2672  SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
2673  // Helper function to check if the given value or a not yet materialized
2674  // replacement of the given value is live.
2675  // Note: `inverseMapping` maps from replaced values to original values.
2676  auto isLive = [&](Value value) {
2677  auto findFn = [&](Operation *user) {
2678  auto matIt = materializationOps.find(user);
2679  if (matIt != materializationOps.end())
2680  return !necessaryMaterializations.count(matIt->second);
2681  return rewriterImpl.isOpIgnored(user);
2682  };
2683  // A worklist is needed because a value may have gone through a chain of
2684  // replacements and each of the replaced values may have live users.
2685  SmallVector<Value> worklist;
2686  worklist.push_back(value);
2687  while (!worklist.empty()) {
2688  Value next = worklist.pop_back_val();
2689  if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
2690  return true;
2691  // This value may be replacing another value that has a live user.
2692  llvm::append_range(worklist, inverseMapping.lookup(next));
2693  }
2694  return false;
2695  };
2696 
2697  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
2698  [&](Value invalidRoot, Value value, Type type) {
2699  // Check to see if the input operation was remapped to a variant of the
2700  // output.
2701  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2702  if (remappedValue.getType() == type && remappedValue != invalidRoot)
2703  return remappedValue;
2704 
2705  // Check to see if the input is a materialization operation that
2706  // provides an inverse conversion. We just check blindly for
2707  // UnrealizedConversionCastOp here, but it has no effect on correctness.
2708  auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
2709  if (inputCastOp && inputCastOp->getNumOperands() == 1)
2710  return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2711  type);
2712 
2713  return Value();
2714  };
2715 
2717  for (auto &rewrite : rewriterImpl.rewrites) {
2718  auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
2719  if (!mat)
2720  continue;
2721  materializationOps.try_emplace(mat->getOperation(), mat);
2722  worklist.insert(mat);
2723  }
2724  while (!worklist.empty()) {
2725  UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
2726  UnrealizedConversionCastOp op = mat->getOperation();
2727 
2728  // We currently only handle target materializations here.
2729  assert(op->getNumResults() == 1 && "unexpected materialization type");
2730  OpResult opResult = op->getOpResult(0);
2731  Type outputType = opResult.getType();
2732  Operation::operand_range inputOperands = op.getOperands();
2733 
2734  // Try to forward propagate operands for user conversion casts that result
2735  // in the input types of the current cast.
2736  for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
2737  auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2738  if (!castOp)
2739  continue;
2740  if (castOp->getResultTypes() == inputOperands.getTypes()) {
2741  replaceMaterialization(rewriterImpl, opResult, inputOperands,
2742  inverseMapping);
2743  necessaryMaterializations.remove(materializationOps.lookup(user));
2744  }
2745  }
2746 
2747  // Try to avoid materializing a resolved materialization if possible.
2748  // Handle the case of a 1-1 materialization.
2749  if (inputOperands.size() == 1) {
2750  // Check to see if the input operation was remapped to a variant of the
2751  // output.
2752  Value remappedValue =
2753  lookupRemappedValue(opResult, inputOperands[0], outputType);
2754  if (remappedValue && remappedValue != opResult) {
2755  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2756  inverseMapping);
2757  necessaryMaterializations.remove(mat);
2758  continue;
2759  }
2760  } else {
2761  // TODO: Avoid materializing other types of conversions here.
2762  }
2763 
2764  // If the materialization does not have any live users, we don't need to
2765  // generate a user materialization for it.
2766  bool isMaterializationLive = isLive(opResult);
2767  if (!isMaterializationLive)
2768  continue;
2769  if (!necessaryMaterializations.insert(mat))
2770  continue;
2771 
2772  // Reprocess input materializations to see if they have an updated status.
2773  for (Value input : inputOperands) {
2774  if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2775  if (auto *mat = materializationOps.lookup(parentOp))
2776  worklist.insert(mat);
2777  }
2778  }
2779  }
2780 }
2781 
2782 /// Legalize the given unresolved materialization. Returns success if the
2783 /// materialization was legalized, failure otherwise.
2785  UnresolvedMaterializationRewrite &mat,
2787  &materializationOps,
2788  ConversionPatternRewriter &rewriter,
2789  ConversionPatternRewriterImpl &rewriterImpl,
2790  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2791  auto findLiveUser = [&](auto &&users) {
2792  auto liveUserIt = llvm::find_if_not(
2793  users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
2794  return liveUserIt == users.end() ? nullptr : *liveUserIt;
2795  };
2796 
2797  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
2798  [&](Value value, Type type) {
2799  // Check to see if the input operation was remapped to a variant of the
2800  // output.
2801  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2802  if (remappedValue.getType() == type)
2803  return remappedValue;
2804  return Value();
2805  };
2806 
2807  UnrealizedConversionCastOp op = mat.getOperation();
2808  if (!rewriterImpl.ignoredOps.insert(op))
2809  return success();
2810 
2811  // We currently only handle target materializations here.
2812  OpResult opResult = op->getOpResult(0);
2813  Operation::operand_range inputOperands = op.getOperands();
2814  Type outputType = opResult.getType();
2815 
2816  // If any input to this materialization is another materialization, resolve
2817  // the input first.
2818  for (Value value : op->getOperands()) {
2819  auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
2820  if (!valueCast)
2821  continue;
2822 
2823  auto matIt = materializationOps.find(valueCast);
2824  if (matIt != materializationOps.end())
2826  *matIt->second, materializationOps, rewriter, rewriterImpl,
2827  inverseMapping)))
2828  return failure();
2829  }
2830 
2831  // Perform a last ditch attempt to avoid materializing a resolved
2832  // materialization if possible.
2833  // Handle the case of a 1-1 materialization.
2834  if (inputOperands.size() == 1) {
2835  // Check to see if the input operation was remapped to a variant of the
2836  // output.
2837  Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2838  if (remappedValue && remappedValue != opResult) {
2839  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2840  inverseMapping);
2841  return success();
2842  }
2843  } else {
2844  // TODO: Avoid materializing other types of conversions here.
2845  }
2846 
2847  // Try to materialize the conversion.
2848  if (const TypeConverter *converter = mat.getConverter()) {
2849  rewriter.setInsertionPoint(op);
2850  Value newMaterialization;
2851  switch (mat.getMaterializationKind()) {
2853  // Try to materialize an argument conversion.
2854  newMaterialization = converter->materializeArgumentConversion(
2855  rewriter, op->getLoc(), outputType, inputOperands);
2856  if (newMaterialization)
2857  break;
2858  // If an argument materialization failed, fallback to trying a target
2859  // materialization.
2860  [[fallthrough]];
2861  case MaterializationKind::Target:
2862  newMaterialization = converter->materializeTargetConversion(
2863  rewriter, op->getLoc(), outputType, inputOperands);
2864  break;
2865  }
2866  if (newMaterialization) {
2867  assert(newMaterialization.getType() == outputType &&
2868  "materialization callback produced value of incorrect type");
2869  replaceMaterialization(rewriterImpl, opResult, newMaterialization,
2870  inverseMapping);
2871  return success();
2872  }
2873  }
2874 
2876  << "failed to legalize unresolved materialization "
2877  "from "
2878  << inputOperands.getTypes() << " to " << outputType
2879  << " that remained live after conversion";
2880  if (Operation *liveUser = findLiveUser(op->getUsers())) {
2881  diag.attachNote(liveUser->getLoc())
2882  << "see existing live user here: " << *liveUser;
2883  }
2884  return failure();
2885 }
2886 
2887 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2888  ConversionPatternRewriter &rewriter,
2889  ConversionPatternRewriterImpl &rewriterImpl,
2890  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
2891  inverseMapping = rewriterImpl.mapping.getInverse();
2892 
2893  // As an initial step, compute all of the inserted materializations that we
2894  // expect to persist beyond the conversion process.
2896  SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
2897  computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
2898  *inverseMapping, necessaryMaterializations);
2899 
2900  // Once computed, legalize any necessary materializations.
2901  for (auto *mat : necessaryMaterializations) {
2903  *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
2904  return failure();
2905  }
2906  return success();
2907 }
2908 
2909 LogicalResult OperationConverter::legalizeErasedResult(
2910  Operation *op, OpResult result,
2911  ConversionPatternRewriterImpl &rewriterImpl) {
2912  // If the operation result was replaced with null, all of the uses of this
2913  // value should be replaced.
2914  auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2915  return rewriterImpl.isOpIgnored(user);
2916  });
2917  if (liveUserIt != result.user_end()) {
2918  InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2919  << op->getName() << "' marked as erased";
2920  diag.attachNote(liveUserIt->getLoc())
2921  << "found live user of result #" << result.getResultNumber() << ": "
2922  << *liveUserIt;
2923  return failure();
2924  }
2925  return success();
2926 }
2927 
2928 /// Finds a user of the given value, or of any other value that the given value
2929 /// replaced, that was not replaced in the conversion process.
2931  Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2932  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2933  SmallVector<Value> worklist(1, initialValue);
2934  while (!worklist.empty()) {
2935  Value value = worklist.pop_back_val();
2936 
2937  // Walk the users of this value to see if there are any live users that
2938  // weren't replaced during conversion.
2939  auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2940  return rewriterImpl.isOpIgnored(user);
2941  });
2942  if (liveUserIt != value.user_end())
2943  return *liveUserIt;
2944  auto mapIt = inverseMapping.find(value);
2945  if (mapIt != inverseMapping.end())
2946  worklist.append(mapIt->second);
2947  }
2948  return nullptr;
2949 }
2950 
2951 LogicalResult OperationConverter::legalizeChangedResultType(
2952  Operation *op, OpResult result, Value newValue,
2953  const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2954  ConversionPatternRewriterImpl &rewriterImpl,
2955  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2956  Operation *liveUser =
2957  findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2958  if (!liveUser)
2959  return success();
2960 
2961  // Functor used to emit a conversion error for a failed materialization.
2962  auto emitConversionError = [&] {
2964  << "failed to materialize conversion for result #"
2965  << result.getResultNumber() << " of operation '"
2966  << op->getName()
2967  << "' that remained live after conversion";
2968  diag.attachNote(liveUser->getLoc())
2969  << "see existing live user here: " << *liveUser;
2970  return failure();
2971  };
2972 
2973  // If the replacement has a type converter, attempt to materialize a
2974  // conversion back to the original type.
2975  if (!replConverter)
2976  return emitConversionError();
2977 
2978  // Materialize a conversion for this live result value.
2979  Type resultType = result.getType();
2980  Value convertedValue = replConverter->materializeSourceConversion(
2981  rewriter, op->getLoc(), resultType, newValue);
2982  if (!convertedValue)
2983  return emitConversionError();
2984 
2985  rewriterImpl.mapping.map(result, convertedValue);
2986  return success();
2987 }
2988 
2989 //===----------------------------------------------------------------------===//
2990 // Type Conversion
2991 //===----------------------------------------------------------------------===//
2992 
2994  ArrayRef<Type> types) {
2995  assert(!types.empty() && "expected valid types");
2996  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2997  addInputs(types);
2998 }
2999 
3001  assert(!types.empty() &&
3002  "1->0 type remappings don't need to be added explicitly");
3003  argTypes.append(types.begin(), types.end());
3004 }
3005 
3006 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3007  unsigned newInputNo,
3008  unsigned newInputCount) {
3009  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3010  assert(newInputCount != 0 && "expected valid input count");
3011  remappedInputs[origInputNo] =
3012  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
3013 }
3014 
3016  Value replacementValue) {
3017  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3018  remappedInputs[origInputNo] =
3019  InputMapping{origInputNo, /*size=*/0, replacementValue};
3020 }
3021 
3023  SmallVectorImpl<Type> &results) const {
3024  {
3025  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3026  std::defer_lock);
3028  cacheReadLock.lock();
3029  auto existingIt = cachedDirectConversions.find(t);
3030  if (existingIt != cachedDirectConversions.end()) {
3031  if (existingIt->second)
3032  results.push_back(existingIt->second);
3033  return success(existingIt->second != nullptr);
3034  }
3035  auto multiIt = cachedMultiConversions.find(t);
3036  if (multiIt != cachedMultiConversions.end()) {
3037  results.append(multiIt->second.begin(), multiIt->second.end());
3038  return success();
3039  }
3040  }
3041  // Walk the added converters in reverse order to apply the most recently
3042  // registered first.
3043  size_t currentCount = results.size();
3044 
3045  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3046  std::defer_lock);
3047 
3048  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
3049  if (std::optional<LogicalResult> result = converter(t, results)) {
3051  cacheWriteLock.lock();
3052  if (!succeeded(*result)) {
3053  cachedDirectConversions.try_emplace(t, nullptr);
3054  return failure();
3055  }
3056  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3057  if (newTypes.size() == 1)
3058  cachedDirectConversions.try_emplace(t, newTypes.front());
3059  else
3060  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3061  return success();
3062  }
3063  }
3064  return failure();
3065 }
3066 
3068  // Use the multi-type result version to convert the type.
3069  SmallVector<Type, 1> results;
3070  if (failed(convertType(t, results)))
3071  return nullptr;
3072 
3073  // Check to ensure that only one type was produced.
3074  return results.size() == 1 ? results.front() : nullptr;
3075 }
3076 
3077 LogicalResult
3079  SmallVectorImpl<Type> &results) const {
3080  for (Type type : types)
3081  if (failed(convertType(type, results)))
3082  return failure();
3083  return success();
3084 }
3085 
3086 bool TypeConverter::isLegal(Type type) const {
3087  return convertType(type) == type;
3088 }
3090  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
3091 }
3092 
3093 bool TypeConverter::isLegal(Region *region) const {
3094  return llvm::all_of(*region, [this](Block &block) {
3095  return isLegal(block.getArgumentTypes());
3096  });
3097 }
3098 
3099 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3100  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
3101 }
3102 
3103 LogicalResult
3105  SignatureConversion &result) const {
3106  // Try to convert the given input type.
3107  SmallVector<Type, 1> convertedTypes;
3108  if (failed(convertType(type, convertedTypes)))
3109  return failure();
3110 
3111  // If this argument is being dropped, there is nothing left to do.
3112  if (convertedTypes.empty())
3113  return success();
3114 
3115  // Otherwise, add the new inputs.
3116  result.addInputs(inputNo, convertedTypes);
3117  return success();
3118 }
3119 LogicalResult
3121  SignatureConversion &result,
3122  unsigned origInputOffset) const {
3123  for (unsigned i = 0, e = types.size(); i != e; ++i)
3124  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3125  return failure();
3126  return success();
3127 }
3128 
3129 Value TypeConverter::materializeConversion(
3130  ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
3131  Location loc, Type resultType, ValueRange inputs) const {
3132  for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
3133  if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
3134  return *result;
3135  return nullptr;
3136 }
3137 
3138 std::optional<TypeConverter::SignatureConversion>
3140  SignatureConversion conversion(block->getNumArguments());
3141  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3142  return std::nullopt;
3143  return conversion;
3144 }
3145 
3146 //===----------------------------------------------------------------------===//
3147 // Type attribute conversion
3148 //===----------------------------------------------------------------------===//
3151  return AttributeConversionResult(attr, resultTag);
3152 }
3153 
3156  return AttributeConversionResult(nullptr, naTag);
3157 }
3158 
3161  return AttributeConversionResult(nullptr, abortTag);
3162 }
3163 
3165  return impl.getInt() == resultTag;
3166 }
3167 
3169  return impl.getInt() == naTag;
3170 }
3171 
3173  return impl.getInt() == abortTag;
3174 }
3175 
3177  assert(hasResult() && "Cannot get result from N/A or abort");
3178  return impl.getPointer();
3179 }
3180 
3181 std::optional<Attribute>
3183  for (const TypeAttributeConversionCallbackFn &fn :
3184  llvm::reverse(typeAttributeConversions)) {
3185  AttributeConversionResult res = fn(type, attr);
3186  if (res.hasResult())
3187  return res.getResult();
3188  if (res.isAbort())
3189  return std::nullopt;
3190  }
3191  return std::nullopt;
3192 }
3193 
3194 //===----------------------------------------------------------------------===//
3195 // FunctionOpInterfaceSignatureConversion
3196 //===----------------------------------------------------------------------===//
3197 
3198 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3199  const TypeConverter &typeConverter,
3200  ConversionPatternRewriter &rewriter) {
3201  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3202  if (!type)
3203  return failure();
3204 
3205  // Convert the original function types.
3206  TypeConverter::SignatureConversion result(type.getNumInputs());
3207  SmallVector<Type, 1> newResults;
3208  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3209  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3210  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3211  typeConverter, &result)))
3212  return failure();
3213 
3214  // Update the function signature in-place.
3215  auto newType = FunctionType::get(rewriter.getContext(),
3216  result.getConvertedTypes(), newResults);
3217 
3218  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3219 
3220  return success();
3221 }
3222 
3223 /// Create a default conversion pattern that rewrites the type signature of a
3224 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3225 /// represent their type.
3226 namespace {
3227 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3228  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3229  MLIRContext *ctx,
3230  const TypeConverter &converter)
3231  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3232 
3233  LogicalResult
3234  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3235  ConversionPatternRewriter &rewriter) const override {
3236  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3237  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3238  }
3239 };
3240 
3241 struct AnyFunctionOpInterfaceSignatureConversion
3242  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3244 
3245  LogicalResult
3246  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3247  ConversionPatternRewriter &rewriter) const override {
3248  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3249  }
3250 };
3251 } // namespace
3252 
3253 FailureOr<Operation *>
3255  const TypeConverter &converter,
3256  ConversionPatternRewriter &rewriter) {
3257  assert(op && "Invalid op");
3258  Location loc = op->getLoc();
3259  if (converter.isLegal(op))
3260  return rewriter.notifyMatchFailure(loc, "op already legal");
3261 
3262  OperationState newOp(loc, op->getName());
3263  newOp.addOperands(operands);
3264 
3265  SmallVector<Type> newResultTypes;
3266  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3267  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3268 
3269  newOp.addTypes(newResultTypes);
3270  newOp.addAttributes(op->getAttrs());
3271  return rewriter.create(newOp);
3272 }
3273 
3275  StringRef functionLikeOpName, RewritePatternSet &patterns,
3276  const TypeConverter &converter) {
3277  patterns.add<FunctionOpInterfaceSignatureConversion>(
3278  functionLikeOpName, patterns.getContext(), converter);
3279 }
3280 
3282  RewritePatternSet &patterns, const TypeConverter &converter) {
3283  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3284  converter, patterns.getContext());
3285 }
3286 
3287 //===----------------------------------------------------------------------===//
3288 // ConversionTarget
3289 //===----------------------------------------------------------------------===//
3290 
3292  LegalizationAction action) {
3293  legalOperations[op].action = action;
3294 }
3295 
3297  LegalizationAction action) {
3298  for (StringRef dialect : dialectNames)
3299  legalDialects[dialect] = action;
3300 }
3301 
3303  -> std::optional<LegalizationAction> {
3304  std::optional<LegalizationInfo> info = getOpInfo(op);
3305  return info ? info->action : std::optional<LegalizationAction>();
3306 }
3307 
3309  -> std::optional<LegalOpDetails> {
3310  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3311  if (!info)
3312  return std::nullopt;
3313 
3314  // Returns true if this operation instance is known to be legal.
3315  auto isOpLegal = [&] {
3316  // Handle dynamic legality either with the provided legality function.
3317  if (info->action == LegalizationAction::Dynamic) {
3318  std::optional<bool> result = info->legalityFn(op);
3319  if (result)
3320  return *result;
3321  }
3322 
3323  // Otherwise, the operation is only legal if it was marked 'Legal'.
3324  return info->action == LegalizationAction::Legal;
3325  };
3326  if (!isOpLegal())
3327  return std::nullopt;
3328 
3329  // This operation is legal, compute any additional legality information.
3330  LegalOpDetails legalityDetails;
3331  if (info->isRecursivelyLegal) {
3332  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3333  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3334  legalityDetails.isRecursivelyLegal =
3335  legalityFnIt->second(op).value_or(true);
3336  } else {
3337  legalityDetails.isRecursivelyLegal = true;
3338  }
3339  }
3340  return legalityDetails;
3341 }
3342 
3344  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3345  if (!info)
3346  return false;
3347 
3348  if (info->action == LegalizationAction::Dynamic) {
3349  std::optional<bool> result = info->legalityFn(op);
3350  if (!result)
3351  return false;
3352 
3353  return !(*result);
3354  }
3355 
3356  return info->action == LegalizationAction::Illegal;
3357 }
3358 
3362  if (!oldCallback)
3363  return newCallback;
3364 
3365  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3366  Operation *op) -> std::optional<bool> {
3367  if (std::optional<bool> result = newCl(op))
3368  return *result;
3369 
3370  return oldCl(op);
3371  };
3372  return chain;
3373 }
3374 
3375 void ConversionTarget::setLegalityCallback(
3376  OperationName name, const DynamicLegalityCallbackFn &callback) {
3377  assert(callback && "expected valid legality callback");
3378  auto *infoIt = legalOperations.find(name);
3379  assert(infoIt != legalOperations.end() &&
3380  infoIt->second.action == LegalizationAction::Dynamic &&
3381  "expected operation to already be marked as dynamically legal");
3382  infoIt->second.legalityFn =
3383  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3384 }
3385 
3387  OperationName name, const DynamicLegalityCallbackFn &callback) {
3388  auto *infoIt = legalOperations.find(name);
3389  assert(infoIt != legalOperations.end() &&
3390  infoIt->second.action != LegalizationAction::Illegal &&
3391  "expected operation to already be marked as legal");
3392  infoIt->second.isRecursivelyLegal = true;
3393  if (callback)
3394  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3395  std::move(opRecursiveLegalityFns[name]), callback);
3396  else
3397  opRecursiveLegalityFns.erase(name);
3398 }
3399 
3400 void ConversionTarget::setLegalityCallback(
3401  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3402  assert(callback && "expected valid legality callback");
3403  for (StringRef dialect : dialects)
3404  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3405  std::move(dialectLegalityFns[dialect]), callback);
3406 }
3407 
3408 void ConversionTarget::setLegalityCallback(
3409  const DynamicLegalityCallbackFn &callback) {
3410  assert(callback && "expected valid legality callback");
3411  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3412 }
3413 
3414 auto ConversionTarget::getOpInfo(OperationName op) const
3415  -> std::optional<LegalizationInfo> {
3416  // Check for info for this specific operation.
3417  const auto *it = legalOperations.find(op);
3418  if (it != legalOperations.end())
3419  return it->second;
3420  // Check for info for the parent dialect.
3421  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3422  if (dialectIt != legalDialects.end()) {
3423  DynamicLegalityCallbackFn callback;
3424  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3425  if (dialectFn != dialectLegalityFns.end())
3426  callback = dialectFn->second;
3427  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3428  callback};
3429  }
3430  // Otherwise, check if we mark unknown operations as dynamic.
3431  if (unknownLegalityFn)
3432  return LegalizationInfo{LegalizationAction::Dynamic,
3433  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3434  return std::nullopt;
3435 }
3436 
3437 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3438 //===----------------------------------------------------------------------===//
3439 // PDL Configuration
3440 //===----------------------------------------------------------------------===//
3441 
3443  auto &rewriterImpl =
3444  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3445  rewriterImpl.currentTypeConverter = getTypeConverter();
3446 }
3447 
3449  auto &rewriterImpl =
3450  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3451  rewriterImpl.currentTypeConverter = nullptr;
3452 }
3453 
3454 /// Remap the given values using the rewriter and the type converter in the
3455 /// provided config.
3456 static FailureOr<SmallVector<Value>>
3458  SmallVector<Value> mappedValues;
3459  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3460  return failure();
3461  return std::move(mappedValues);
3462 }
3463 
3465  patterns.getPDLPatterns().registerRewriteFunction(
3466  "convertValue",
3467  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3468  auto results = pdllConvertValues(
3469  static_cast<ConversionPatternRewriter &>(rewriter), value);
3470  if (failed(results))
3471  return failure();
3472  return results->front();
3473  });
3474  patterns.getPDLPatterns().registerRewriteFunction(
3475  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3476  return pdllConvertValues(
3477  static_cast<ConversionPatternRewriter &>(rewriter), values);
3478  });
3479  patterns.getPDLPatterns().registerRewriteFunction(
3480  "convertType",
3481  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3482  auto &rewriterImpl =
3483  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3484  if (const TypeConverter *converter =
3485  rewriterImpl.currentTypeConverter) {
3486  if (Type newType = converter->convertType(type))
3487  return newType;
3488  return failure();
3489  }
3490  return type;
3491  });
3492  patterns.getPDLPatterns().registerRewriteFunction(
3493  "convertTypes",
3494  [](PatternRewriter &rewriter,
3495  TypeRange types) -> FailureOr<SmallVector<Type>> {
3496  auto &rewriterImpl =
3497  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3498  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3499  if (!converter)
3500  return SmallVector<Type>(types);
3501 
3502  SmallVector<Type> remappedTypes;
3503  if (failed(converter->convertTypes(types, remappedTypes)))
3504  return failure();
3505  return std::move(remappedTypes);
3506  });
3507 }
3508 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3509 
3510 //===----------------------------------------------------------------------===//
3511 // Op Conversion Entry Points
3512 //===----------------------------------------------------------------------===//
3513 
3514 //===----------------------------------------------------------------------===//
3515 // Partial Conversion
3516 
3518  ArrayRef<Operation *> ops, const ConversionTarget &target,
3519  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3520  OperationConverter opConverter(target, patterns, config,
3521  OpConversionMode::Partial);
3522  return opConverter.convertOperations(ops);
3523 }
3524 LogicalResult
3526  const FrozenRewritePatternSet &patterns,
3527  ConversionConfig config) {
3528  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3529 }
3530 
3531 //===----------------------------------------------------------------------===//
3532 // Full Conversion
3533 
3535  const ConversionTarget &target,
3536  const FrozenRewritePatternSet &patterns,
3537  ConversionConfig config) {
3538  OperationConverter opConverter(target, patterns, config,
3539  OpConversionMode::Full);
3540  return opConverter.convertOperations(ops);
3541 }
3543  const ConversionTarget &target,
3544  const FrozenRewritePatternSet &patterns,
3545  ConversionConfig config) {
3546  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3547 }
3548 
3549 //===----------------------------------------------------------------------===//
3550 // Analysis Conversion
3551 
3554  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3555  OperationConverter opConverter(target, patterns, config,
3556  OpConversionMode::Analysis);
3557  return opConverter.convertOperations(ops);
3558 }
3559 LogicalResult
3561  const FrozenRewritePatternSet &patterns,
3562  ConversionConfig config) {
3563  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3564 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given values using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterializationRewrite &mat, DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
static RewriteTy * findSingleRewrite(R &&rewrites, Block *block)
Find the single rewrite object of the specified type and block among the given rewrites.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterializationRewrite * > &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process,...
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
bool empty()
Definition: Block.h:146
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:93
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator end()
Definition: Block.h:142
iterator begin()
Definition: Block.h:141
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:329
Block::iterator getPoint() const
Definition: Builders.h:342
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:339
Block * getBlock() const
Definition: Builders.h:341
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:479
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
OpResult getOpResult(unsigned idx)
Definition: Operation.h:416
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:605
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:892
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:281
MLIRContext * getContext() const
Definition: PatternMatch.h:823
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_iterator user_end() const
Definition: Value.h:227
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition: Argument.h:64
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a single value that remaps an existing signature input...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
DenseSet< void * > erased
Pointers to all erased operations and blocks.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
Value buildUnresolvedMaterialization(MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks within that region.
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc, ValueRange inputs, Type outputType, const TypeConverter *converter)
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites remain.
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
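Ordered list of IR rewrites (block creations, splits, and motions, as well as operation rewrites), kept in the order they were requested so they can be applied or undone.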
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, const TypeConverter *converter)
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.