MLIR  19.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Iterators.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::detail;
29 
30 #define DEBUG_TYPE "dialect-conversion"
31 
32 /// A utility function to log a successful result for the given reason.
33 template <typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
35  LLVM_DEBUG({
36  os.unindent();
37  os.startLine() << "} -> SUCCESS";
38  if (!fmt.empty())
39  os.getOStream() << " : "
40  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41  os.getOStream() << "\n";
42  });
43 }
44 
45 /// A utility function to log a failure result for the given reason.
46 template <typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
48  LLVM_DEBUG({
49  os.unindent();
50  os.startLine() << "} -> FAILURE : "
51  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
52  << "\n";
53  });
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // ConversionValueMapping
58 //===----------------------------------------------------------------------===//
59 
60 namespace {
61 /// This class wraps a IRMapping to provide recursive lookup
62 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
63 struct ConversionValueMapping {
64  /// Lookup a mapped value within the map. If a mapping for the provided value
65  /// does not exist then return the provided value. If `desiredType` is
66  /// non-null, returns the most recently mapped value with that type. If an
67  /// operand of that type does not exist, defaults to normal behavior.
68  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
69 
70  /// Lookup a mapped value within the map, or return null if a mapping does not
71  /// exist. If a mapping exists, this follows the same behavior of
72  /// `lookupOrDefault`.
73  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
74 
75  /// Map a value to the one provided.
76  void map(Value oldVal, Value newVal) {
77  LLVM_DEBUG({
78  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
79  assert(it != oldVal && "inserting cyclic mapping");
80  });
81  mapping.map(oldVal, newVal);
82  }
83 
84  /// Try to map a value to the one provided. Returns false if a transitive
85  /// mapping from the new value to the old value already exists, true if the
86  /// map was updated.
87  bool tryMap(Value oldVal, Value newVal);
88 
89  /// Drop the last mapping for the given value.
90  void erase(Value value) { mapping.erase(value); }
91 
92  /// Returns the inverse raw value mapping (without recursive query support).
93  DenseMap<Value, SmallVector<Value>> getInverse() const {
95  for (auto &it : mapping.getValueMap())
96  inverse[it.second].push_back(it.first);
97  return inverse;
98  }
99 
100 private:
101  /// Current value mappings.
102  IRMapping mapping;
103 };
104 } // namespace
105 
106 Value ConversionValueMapping::lookupOrDefault(Value from,
107  Type desiredType) const {
108  // If there was no desired type, simply find the leaf value.
109  if (!desiredType) {
110  // If this value had a valid mapping, unmap that value as well in the case
111  // that it was also replaced.
112  while (auto mappedValue = mapping.lookupOrNull(from))
113  from = mappedValue;
114  return from;
115  }
116 
117  // Otherwise, try to find the deepest value that has the desired type.
118  Value desiredValue;
119  do {
120  if (from.getType() == desiredType)
121  desiredValue = from;
122 
123  Value mappedValue = mapping.lookupOrNull(from);
124  if (!mappedValue)
125  break;
126  from = mappedValue;
127  } while (true);
128 
129  // If the desired value was found use it, otherwise default to the leaf value.
130  return desiredValue ? desiredValue : from;
131 }
132 
133 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
134  Value result = lookupOrDefault(from, desiredType);
135  if (result == from || (desiredType && result.getType() != desiredType))
136  return nullptr;
137  return result;
138 }
139 
140 bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
141  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
142  if (it == oldVal)
143  return false;
144  map(oldVal, newVal);
145  return true;
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // Rewriter and Translation State
150 //===----------------------------------------------------------------------===//
151 namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters, taken so that
/// a set of rewrites can later be undone back to this point.
struct RewriterState {
  RewriterState(unsigned rewrites, unsigned ignoredOps, unsigned replacedOps)
      : numRewrites(rewrites), numIgnoredOperations(ignoredOps),
        numReplacedOps(replacedOps) {}

  /// Number of rewrites recorded when the snapshot was taken.
  unsigned numRewrites;

  /// Number of ignored operations when the snapshot was taken.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure when the snapshot was taken.
  unsigned numReplacedOps;
};
169 
170 //===----------------------------------------------------------------------===//
171 // IR rewrites
172 //===----------------------------------------------------------------------===//
173 
174 /// An IR rewrite that can be committed (upon success) or rolled back (upon
175 /// failure).
176 ///
177 /// The dialect conversion keeps track of IR modifications (requested by the
178 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
179 /// are directly applied to the IR as the rewriter API is used, some are applied
180 /// partially, and some are delayed until the `IRRewrite` objects are committed.
181 class IRRewrite {
182 public:
183  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
184  /// Enum values are ordered, so that they can be used in `classof`: first all
185  /// block rewrites, then all operation rewrites.
186  enum class Kind {
187  // Block rewrites
188  CreateBlock,
189  EraseBlock,
190  InlineBlock,
191  MoveBlock,
192  BlockTypeConversion,
193  ReplaceBlockArg,
194  // Operation rewrites
195  MoveOperation,
196  ModifyOperation,
197  ReplaceOperation,
198  CreateOperation,
199  UnresolvedMaterialization
200  };
201 
202  virtual ~IRRewrite() = default;
203 
204  /// Roll back the rewrite. Operations may be erased during rollback.
205  virtual void rollback() = 0;
206 
207  /// Commit the rewrite. At this point, it is certain that the dialect
208  /// conversion will succeed. All IR modifications, except for operation/block
209  /// erasure, must be performed through the given rewriter.
210  ///
211  /// Instead of erasing operations/blocks, they should merely be unlinked
212  /// commit phase and finally be erased during the cleanup phase. This is
213  /// because internal dialect conversion state (such as `mapping`) may still
214  /// be using them.
215  ///
216  /// Any IR modification that was already performed before the commit phase
217  /// (e.g., insertion of an op) must be communicated to the listener that may
218  /// be attached to the given rewriter.
219  virtual void commit(RewriterBase &rewriter) {}
220 
221  /// Cleanup operations/blocks. Cleanup is called after commit.
222  virtual void cleanup(RewriterBase &rewriter) {}
223 
224  Kind getKind() const { return kind; }
225 
226  static bool classof(const IRRewrite *rewrite) { return true; }
227 
228 protected:
229  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
230  : kind(kind), rewriterImpl(rewriterImpl) {}
231 
232  const ConversionConfig &getConfig() const;
233 
234  const Kind kind;
235  ConversionPatternRewriterImpl &rewriterImpl;
236 };
237 
238 /// A block rewrite.
239 class BlockRewrite : public IRRewrite {
240 public:
241  /// Return the block that this rewrite operates on.
242  Block *getBlock() const { return block; }
243 
244  static bool classof(const IRRewrite *rewrite) {
245  return rewrite->getKind() >= Kind::CreateBlock &&
246  rewrite->getKind() <= Kind::ReplaceBlockArg;
247  }
248 
249 protected:
250  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
251  Block *block)
252  : IRRewrite(kind, rewriterImpl), block(block) {}
253 
254  // The block that this rewrite operates on.
255  Block *block;
256 };
257 
258 /// Creation of a block. Block creations are immediately reflected in the IR.
259 /// There is no extra work to commit the rewrite. During rollback, the newly
260 /// created block is erased.
261 class CreateBlockRewrite : public BlockRewrite {
262 public:
263  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
264  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
265 
266  static bool classof(const IRRewrite *rewrite) {
267  return rewrite->getKind() == Kind::CreateBlock;
268  }
269 
270  void commit(RewriterBase &rewriter) override {
271  // The block was already created and inserted. Just inform the listener.
272  if (auto *listener = rewriter.getListener())
273  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
274  }
275 
276  void rollback() override {
277  // Unlink all of the operations within this block, they will be deleted
278  // separately.
279  auto &blockOps = block->getOperations();
280  while (!blockOps.empty())
281  blockOps.remove(blockOps.begin());
282  block->dropAllUses();
283  if (block->getParent())
284  block->erase();
285  else
286  delete block;
287  }
288 };
289 
290 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
291 /// blocks are immediately unlinked, but only erased during cleanup. This makes
292 /// it easier to rollback a block erasure: the block is simply inserted into its
293 /// original location.
294 class EraseBlockRewrite : public BlockRewrite {
295 public:
296  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
297  Region *region, Block *insertBeforeBlock)
298  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), region(region),
299  insertBeforeBlock(insertBeforeBlock) {}
300 
301  static bool classof(const IRRewrite *rewrite) {
302  return rewrite->getKind() == Kind::EraseBlock;
303  }
304 
305  ~EraseBlockRewrite() override {
306  assert(!block &&
307  "rewrite was neither rolled back nor committed/cleaned up");
308  }
309 
310  void rollback() override {
311  // The block (owned by this rewrite) was not actually erased yet. It was
312  // just unlinked. Put it back into its original position.
313  assert(block && "expected block");
314  auto &blockList = region->getBlocks();
315  Region::iterator before = insertBeforeBlock
316  ? Region::iterator(insertBeforeBlock)
317  : blockList.end();
318  blockList.insert(before, block);
319  block = nullptr;
320  }
321 
322  void commit(RewriterBase &rewriter) override {
323  // Erase the block.
324  assert(block && "expected block");
325  assert(block->empty() && "expected empty block");
326 
327  // Notify the listener that the block is about to be erased.
328  if (auto *listener =
329  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
330  listener->notifyBlockErased(block);
331  }
332 
333  void cleanup(RewriterBase &rewriter) override {
334  // Erase the block.
335  block->dropAllDefinedValueUses();
336  delete block;
337  block = nullptr;
338  }
339 
340 private:
341  // The region in which this block was previously contained.
342  Region *region;
343 
344  // The original successor of this block before it was unlinked. "nullptr" if
345  // this block was the only block in the region.
346  Block *insertBeforeBlock;
347 };
348 
349 /// Inlining of a block. This rewrite is immediately reflected in the IR.
350 /// Note: This rewrite represents only the inlining of the operations. The
351 /// erasure of the inlined block is a separate rewrite.
352 class InlineBlockRewrite : public BlockRewrite {
353 public:
354  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
355  Block *sourceBlock, Block::iterator before)
356  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
357  sourceBlock(sourceBlock),
358  firstInlinedInst(sourceBlock->empty() ? nullptr
359  : &sourceBlock->front()),
360  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
361  // If a listener is attached to the dialect conversion, ops must be moved
362  // one-by-one. When they are moved in bulk, notifications cannot be sent
363  // because the ops that used to be in the source block at the time of the
364  // inlining (before the "commit" phase) are unknown at the time when
365  // notifications are sent (which is during the "commit" phase).
366  assert(!getConfig().listener &&
367  "InlineBlockRewrite not supported if listener is attached");
368  }
369 
370  static bool classof(const IRRewrite *rewrite) {
371  return rewrite->getKind() == Kind::InlineBlock;
372  }
373 
374  void rollback() override {
375  // Put the operations from the destination block (owned by the rewrite)
376  // back into the source block.
377  if (firstInlinedInst) {
378  assert(lastInlinedInst && "expected operation");
379  sourceBlock->getOperations().splice(sourceBlock->begin(),
380  block->getOperations(),
381  Block::iterator(firstInlinedInst),
382  ++Block::iterator(lastInlinedInst));
383  }
384  }
385 
386 private:
387  // The block that originally contained the operations.
388  Block *sourceBlock;
389 
390  // The first inlined operation.
391  Operation *firstInlinedInst;
392 
393  // The last inlined operation.
394  Operation *lastInlinedInst;
395 };
396 
397 /// Moving of a block. This rewrite is immediately reflected in the IR.
398 class MoveBlockRewrite : public BlockRewrite {
399 public:
400  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
401  Region *region, Block *insertBeforeBlock)
402  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
403  insertBeforeBlock(insertBeforeBlock) {}
404 
405  static bool classof(const IRRewrite *rewrite) {
406  return rewrite->getKind() == Kind::MoveBlock;
407  }
408 
409  void commit(RewriterBase &rewriter) override {
410  // The block was already moved. Just inform the listener.
411  if (auto *listener = rewriter.getListener()) {
412  // Note: `previousIt` cannot be passed because this is a delayed
413  // notification and iterators into past IR state cannot be represented.
414  listener->notifyBlockInserted(block, /*previous=*/region,
415  /*previousIt=*/{});
416  }
417  }
418 
419  void rollback() override {
420  // Move the block back to its original position.
421  Region::iterator before =
422  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
423  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
424  }
425 
426 private:
427  // The region in which this block was previously contained.
428  Region *region;
429 
430  // The original successor of this block before it was moved. "nullptr" if
431  // this block was the only block in the region.
432  Block *insertBeforeBlock;
433 };
434 
435 /// This structure contains the information pertaining to an argument that has
436 /// been converted.
437 struct ConvertedArgInfo {
438  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
439  Value castValue = nullptr)
440  : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
441 
442  /// The start index of in the new argument list that contains arguments that
443  /// replace the original.
444  unsigned newArgIdx;
445 
446  /// The number of arguments that replaced the original argument.
447  unsigned newArgSize;
448 
449  /// The cast value that was created to cast from the new arguments to the
450  /// old. This only used if 'newArgSize' > 1.
451  Value castValue;
452 };
453 
454 /// Block type conversion. This rewrite is partially reflected in the IR.
455 class BlockTypeConversionRewrite : public BlockRewrite {
456 public:
457  BlockTypeConversionRewrite(
458  ConversionPatternRewriterImpl &rewriterImpl, Block *block,
459  Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
460  const TypeConverter *converter)
461  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
462  origBlock(origBlock), argInfo(argInfo), converter(converter) {}
463 
464  static bool classof(const IRRewrite *rewrite) {
465  return rewrite->getKind() == Kind::BlockTypeConversion;
466  }
467 
468  /// Materialize any necessary conversions for converted arguments that have
469  /// live users, using the provided `findLiveUser` to search for a user that
470  /// survives the conversion process.
472  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
473 
474  void commit(RewriterBase &rewriter) override;
475 
476  void rollback() override;
477 
478 private:
479  /// The original block that was requested to have its signature converted.
480  Block *origBlock;
481 
482  /// The conversion information for each of the arguments. The information is
483  /// std::nullopt if the argument was dropped during conversion.
485 
486  /// The type converter used to convert the arguments.
487  const TypeConverter *converter;
488 };
489 
490 /// Replacing a block argument. This rewrite is not immediately reflected in the
491 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
492 /// until the rewrite is committed.
493 class ReplaceBlockArgRewrite : public BlockRewrite {
494 public:
495  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
496  Block *block, BlockArgument arg)
497  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
498 
499  static bool classof(const IRRewrite *rewrite) {
500  return rewrite->getKind() == Kind::ReplaceBlockArg;
501  }
502 
503  void commit(RewriterBase &rewriter) override;
504 
505  void rollback() override;
506 
507 private:
508  BlockArgument arg;
509 };
510 
511 /// An operation rewrite.
512 class OperationRewrite : public IRRewrite {
513 public:
514  /// Return the operation that this rewrite operates on.
515  Operation *getOperation() const { return op; }
516 
517  static bool classof(const IRRewrite *rewrite) {
518  return rewrite->getKind() >= Kind::MoveOperation &&
519  rewrite->getKind() <= Kind::UnresolvedMaterialization;
520  }
521 
522 protected:
523  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
524  Operation *op)
525  : IRRewrite(kind, rewriterImpl), op(op) {}
526 
527  // The operation that this rewrite operates on.
528  Operation *op;
529 };
530 
531 /// Moving of an operation. This rewrite is immediately reflected in the IR.
532 class MoveOperationRewrite : public OperationRewrite {
533 public:
534  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
535  Operation *op, Block *block, Operation *insertBeforeOp)
536  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
537  insertBeforeOp(insertBeforeOp) {}
538 
539  static bool classof(const IRRewrite *rewrite) {
540  return rewrite->getKind() == Kind::MoveOperation;
541  }
542 
543  void commit(RewriterBase &rewriter) override {
544  // The operation was already moved. Just inform the listener.
545  if (auto *listener = rewriter.getListener()) {
546  // Note: `previousIt` cannot be passed because this is a delayed
547  // notification and iterators into past IR state cannot be represented.
548  listener->notifyOperationInserted(
549  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
550  /*insertPt=*/{}));
551  }
552  }
553 
554  void rollback() override {
555  // Move the operation back to its original position.
556  Block::iterator before =
557  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
558  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
559  }
560 
561 private:
562  // The block in which this operation was previously contained.
563  Block *block;
564 
565  // The original successor of this operation before it was moved. "nullptr"
566  // if this operation was the only operation in the region.
567  Operation *insertBeforeOp;
568 };
569 
570 /// In-place modification of an op. This rewrite is immediately reflected in
571 /// the IR. The previous state of the operation is stored in this object.
572 class ModifyOperationRewrite : public OperationRewrite {
573 public:
574  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
575  Operation *op)
576  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
577  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
578  operands(op->operand_begin(), op->operand_end()),
579  successors(op->successor_begin(), op->successor_end()) {
580  if (OpaqueProperties prop = op->getPropertiesStorage()) {
581  // Make a copy of the properties.
582  propertiesStorage = operator new(op->getPropertiesStorageSize());
583  OpaqueProperties propCopy(propertiesStorage);
584  name.initOpProperties(propCopy, /*init=*/prop);
585  }
586  }
587 
588  static bool classof(const IRRewrite *rewrite) {
589  return rewrite->getKind() == Kind::ModifyOperation;
590  }
591 
592  ~ModifyOperationRewrite() override {
593  assert(!propertiesStorage &&
594  "rewrite was neither committed nor rolled back");
595  }
596 
597  void commit(RewriterBase &rewriter) override {
598  // Notify the listener that the operation was modified in-place.
599  if (auto *listener =
600  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
601  listener->notifyOperationModified(op);
602 
603  if (propertiesStorage) {
604  OpaqueProperties propCopy(propertiesStorage);
605  // Note: The operation may have been erased in the mean time, so
606  // OperationName must be stored in this object.
607  name.destroyOpProperties(propCopy);
608  operator delete(propertiesStorage);
609  propertiesStorage = nullptr;
610  }
611  }
612 
613  void rollback() override {
614  op->setLoc(loc);
615  op->setAttrs(attrs);
616  op->setOperands(operands);
617  for (const auto &it : llvm::enumerate(successors))
618  op->setSuccessor(it.value(), it.index());
619  if (propertiesStorage) {
620  OpaqueProperties propCopy(propertiesStorage);
621  op->copyProperties(propCopy);
622  name.destroyOpProperties(propCopy);
623  operator delete(propertiesStorage);
624  propertiesStorage = nullptr;
625  }
626  }
627 
628 private:
629  OperationName name;
630  LocationAttr loc;
631  DictionaryAttr attrs;
632  SmallVector<Value, 8> operands;
633  SmallVector<Block *, 2> successors;
634  void *propertiesStorage = nullptr;
635 };
636 
637 /// Replacing an operation. Erasing an operation is treated as a special case
638 /// with "null" replacements. This rewrite is not immediately reflected in the
639 /// IR. An internal IR mapping is updated, but values are not replaced and the
640 /// original op is not erased until the rewrite is committed.
641 class ReplaceOperationRewrite : public OperationRewrite {
642 public:
643  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
644  Operation *op, const TypeConverter *converter,
645  bool changedResults)
646  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
647  converter(converter), changedResults(changedResults) {}
648 
649  static bool classof(const IRRewrite *rewrite) {
650  return rewrite->getKind() == Kind::ReplaceOperation;
651  }
652 
653  void commit(RewriterBase &rewriter) override;
654 
655  void rollback() override;
656 
657  void cleanup(RewriterBase &rewriter) override;
658 
659  const TypeConverter *getConverter() const { return converter; }
660 
661  bool hasChangedResults() const { return changedResults; }
662 
663 private:
664  /// An optional type converter that can be used to materialize conversions
665  /// between the new and old values if necessary.
666  const TypeConverter *converter;
667 
668  /// A boolean flag that indicates whether result types have changed or not.
669  bool changedResults;
670 };
671 
672 class CreateOperationRewrite : public OperationRewrite {
673 public:
674  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
675  Operation *op)
676  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
677 
678  static bool classof(const IRRewrite *rewrite) {
679  return rewrite->getKind() == Kind::CreateOperation;
680  }
681 
682  void commit(RewriterBase &rewriter) override {
683  // The operation was already created and inserted. Just inform the listener.
684  if (auto *listener = rewriter.getListener())
685  listener->notifyOperationInserted(op, /*previous=*/{});
686  }
687 
688  void rollback() override;
689 };
690 
/// Distinguishes the two flavors of unresolved materialization.
enum MaterializationKind {
  /// Converts an illegal block argument type to a legal one.
  Argument = 0,

  /// Converts an illegal type to a legal one.
  Target = 1
};
701 
702 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
703 /// op. Unresolved materializations are erased at the end of the dialect
704 /// conversion.
705 class UnresolvedMaterializationRewrite : public OperationRewrite {
706 public:
707  UnresolvedMaterializationRewrite(
708  ConversionPatternRewriterImpl &rewriterImpl,
709  UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
710  MaterializationKind kind = MaterializationKind::Target,
711  Type origOutputType = nullptr)
712  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
713  converterAndKind(converter, kind), origOutputType(origOutputType) {}
714 
715  static bool classof(const IRRewrite *rewrite) {
716  return rewrite->getKind() == Kind::UnresolvedMaterialization;
717  }
718 
719  UnrealizedConversionCastOp getOperation() const {
720  return cast<UnrealizedConversionCastOp>(op);
721  }
722 
723  void rollback() override;
724 
725  void cleanup(RewriterBase &rewriter) override;
726 
727  /// Return the type converter of this materialization (which may be null).
728  const TypeConverter *getConverter() const {
729  return converterAndKind.getPointer();
730  }
731 
732  /// Return the kind of this materialization.
733  MaterializationKind getMaterializationKind() const {
734  return converterAndKind.getInt();
735  }
736 
737  /// Set the kind of this materialization.
738  void setMaterializationKind(MaterializationKind kind) {
739  converterAndKind.setInt(kind);
740  }
741 
742  /// Return the original illegal output type of the input values.
743  Type getOrigOutputType() const { return origOutputType; }
744 
745 private:
746  /// The corresponding type converter to use when resolving this
747  /// materialization, and the kind of this materialization.
748  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
749  converterAndKind;
750 
751  /// The original output type. This is only used for argument conversions.
752  Type origOutputType;
753 };
754 } // namespace
755 
756 /// Return "true" if there is an operation rewrite that matches the specified
757 /// rewrite type and operation among the given rewrites.
758 template <typename RewriteTy, typename R>
759 static bool hasRewrite(R &&rewrites, Operation *op) {
760  return any_of(std::move(rewrites), [&](auto &rewrite) {
761  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
762  return rewriteTy && rewriteTy->getOperation() == op;
763  });
764 }
765 
766 /// Find the single rewrite object of the specified type and block among the
767 /// given rewrites. In debug mode, asserts that there is mo more than one such
768 /// object. Return "nullptr" if no object was found.
769 template <typename RewriteTy, typename R>
770 static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
771  RewriteTy *result = nullptr;
772  for (auto &rewrite : rewrites) {
773  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
774  if (rewriteTy && rewriteTy->getBlock() == block) {
775 #ifndef NDEBUG
776  assert(!result && "expected single matching rewrite");
777  result = rewriteTy;
778 #else
779  return rewriteTy;
780 #endif // NDEBUG
781  }
782  }
783  return result;
784 }
785 
786 //===----------------------------------------------------------------------===//
787 // ConversionPatternRewriterImpl
788 //===----------------------------------------------------------------------===//
789 namespace mlir {
790 namespace detail {
793  const ConversionConfig &config)
794  : context(ctx), config(config) {}
795 
796  //===--------------------------------------------------------------------===//
797  // State Management
798  //===--------------------------------------------------------------------===//
799 
800  /// Return the current state of the rewriter.
801  RewriterState getCurrentState();
802 
803  /// Apply all requested operation rewrites. This method is invoked when the
804  /// conversion process succeeds.
805  void applyRewrites();
806 
807  /// Reset the state of the rewriter to a previously saved point.
808  void resetState(RewriterState state);
809 
810  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
811  /// failure.
812  template <typename RewriteTy, typename... Args>
813  void appendRewrite(Args &&...args) {
814  rewrites.push_back(
815  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
816  }
817 
818  /// Undo the rewrites (motions, splits) one by one in reverse order until
819  /// "numRewritesToKeep" rewrites remains.
820  void undoRewrites(unsigned numRewritesToKeep = 0);
821 
822  /// Remap the given values to those with potentially different types. Returns
823  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
824  /// is the tag used when describing a value within a diagnostic, e.g.
825  /// "operand".
826  LogicalResult remapValues(StringRef valueDiagTag,
827  std::optional<Location> inputLoc,
828  PatternRewriter &rewriter, ValueRange values,
829  SmallVectorImpl<Value> &remapped);
830 
831  /// Return "true" if the given operation is ignored, and does not need to be
832  /// converted.
833  bool isOpIgnored(Operation *op) const;
834 
835  /// Return "true" if the given operation was replaced or erased.
836  bool wasOpReplaced(Operation *op) const;
837 
  //===--------------------------------------------------------------------===//
  // Type Conversion
  //===--------------------------------------------------------------------===//

  /// Attempt to convert the signature of the given block, if successful a new
  /// block is returned containing the new arguments. Returns `block` if it did
  /// not require conversion. If `conversion` is provided it is applied as is;
  /// otherwise `converter` is asked for a signature conversion. Fails when
  /// neither is available or when the converter cannot convert the signature.
  FailureOr<Block *> convertBlockSignature(
      ConversionPatternRewriter &rewriter, Block *block,
      const TypeConverter *converter,
      TypeConverter::SignatureConversion *conversion = nullptr);

  /// Convert the types of non-entry block arguments within the given region.
  /// `blockConversions` must be either empty or contain exactly one conversion
  /// per non-entry block, in block order. `converter` is also recorded as the
  /// type converter of `region`.
  LogicalResult convertNonEntryRegionTypes(
      ConversionPatternRewriter &rewriter, Region *region,
      const TypeConverter &converter,
      ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});

  /// Apply a signature conversion on the given region, using `converter` for
  /// materializations if not null. Returns the new entry block, or nullptr if
  /// the region is empty.
  Block *
  applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
                           const TypeConverter *converter);

  /// Convert the types of block arguments within the given region. Non-entry
  /// blocks are converted with `converter`; the entry block uses
  /// `entryConversion` if provided. Returns nullptr for an empty region.
  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
                     const TypeConverter &converter,
                     TypeConverter::SignatureConversion *entryConversion);

  /// Apply the given signature conversion on the given block. The new block
  /// containing the updated signature is returned. If no conversions were
  /// necessary, e.g. if the block has no arguments, `block` is returned.
  /// `converter` is used to generate any necessary cast operations that
  /// translate between the origin argument types and those specified in the
  /// signature conversion.
  Block *applySignatureConversion(
      ConversionPatternRewriter &rewriter, Block *block,
      const TypeConverter *converter,
      TypeConverter::SignatureConversion &signatureConversion);
879 
  //===--------------------------------------------------------------------===//
  // Materializations
  //===--------------------------------------------------------------------===//
  /// Build an unresolved materialization operation given an output type and set
  /// of input operands. The cast is inserted at `insertPt` in `insertBlock`. If
  /// `inputs` is a single value that already has `outputType`, that value is
  /// returned and no operation is created.
  Value buildUnresolvedMaterialization(MaterializationKind kind,
                                       Block *insertBlock,
                                       Block::iterator insertPt, Location loc,
                                       ValueRange inputs, Type outputType,
                                       Type origOutputType,
                                       const TypeConverter *converter);

  /// Build an argument materialization at the start of `block`.
  /// Note: unlike `buildUnresolvedMaterialization`, this takes
  /// `origOutputType` *before* `outputType`.
  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
                                               ValueRange inputs,
                                               Type origOutputType,
                                               Type outputType,
                                               const TypeConverter *converter);

  /// Build a target materialization of `input` to `outputType`, inserted right
  /// after the op defining `input` (or at the start of its block if `input` is
  /// a block argument).
  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
                                             Type outputType,
                                             const TypeConverter *converter);
901 
  //===--------------------------------------------------------------------===//
  // Rewriter Notification Hooks
  //===--------------------------------------------------------------------===//

  /// Notifies that an op was inserted. If `previous` is set, the op was moved
  /// from that insertion point; otherwise it was newly created.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override;

  /// Notifies that an op is about to be replaced with the given values. A null
  /// entry in `newValues` means the corresponding result is dropped (erasure).
  void notifyOpReplaced(Operation *op, ValueRange newValues);

  /// Notifies that a block is about to be erased.
  void notifyBlockIsBeingErased(Block *block);

  /// Notifies that a block was inserted. If `previous` is non-null, the block
  /// was moved from `previous` (before `previousIt`); otherwise it is new.
  void notifyBlockInserted(Block *block, Region *previous,
                           Region::iterator previousIt) override;

  /// Notifies that a block is being inlined into another block.
  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
                               Block::iterator before);

  /// Notifies that a pattern match failed for the given reason. The reason is
  /// only built and reported in debug builds.
  void
  notifyMatchFailure(Location loc,
                     function_ref<void(Diagnostic &)> reasonCallback) override;
928 
  //===--------------------------------------------------------------------===//
  // IR Erasure
  //===--------------------------------------------------------------------===//

  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
  /// operation or block is erased multiple times. This rewriter assumes that
  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
  /// The rewriter is registered as its own listener, so the `notify*Erased`
  /// hooks below are the ones that populate `erased`.
  public:
      : RewriterBase(context, /*listener=*/this) {}

    /// Erase the given op (unless it was already erased).
    void eraseOp(Operation *op) override {
      // Repeated requests for the same op are ignored.
      if (erased.contains(op))
        return;
      op->dropAllUses();
    }

    /// Erase the given block (unless it was already erased).
    void eraseBlock(Block *block) override {
      // Repeated requests for the same block are ignored.
      if (erased.contains(block))
        return;
      // The block's operations must have been erased before the block itself.
      assert(block->empty() && "expected empty block");
      block->dropAllDefinedValueUses();
    }

    /// Remember the op so that later eraseOp() calls on it are no-ops.
    void notifyOperationErased(Operation *op) override { erased.insert(op); }

    /// Remember the block so that later eraseBlock() calls on it are no-ops.
    void notifyBlockErased(Block *block) override { erased.insert(block); }

    /// Pointers to all erased operations and blocks.
  };
965 
  //===--------------------------------------------------------------------===//
  // State
  //===--------------------------------------------------------------------===//

  /// MLIR context.

  /// Mapping from replaced values to their replacements. Replacements may or
  /// may not have the same type as the value they replace.
  ConversionValueMapping mapping;

  /// Ordered log of all IR rewrites (op/block creations, motions, inlinings,
  /// replacements, materializations, ...). The log is committed in order on
  /// success and undone in reverse order on failure.

  /// A set of operations that should no longer be considered for legalization.
  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
  /// tracked separately.

  /// A set of operations that were replaced/erased. Such ops are not erased
  /// immediately but only when the dialect conversion succeeds. In the
  /// meantime, they should no longer be considered for legalization and any
  /// attempt to modify/access them is invalid rewriter API usage.

  /// The current type converter, or nullptr if no type converter is currently
  /// active.
  const TypeConverter *currentTypeConverter = nullptr;

  /// A mapping of regions to type converters that should be used when
  /// converting the arguments of blocks within that region.

  /// Dialect conversion configuration.

#ifndef NDEBUG
  /// A set of operations that have pending updates. This tracking isn't
  /// strictly necessary, and is thus only active during debug builds for extra
  /// verification.

  /// A logger used to print debug output during the conversion process.
  llvm::ScopedPrinter logger{llvm::dbgs()};
#endif
1011 };
1012 } // namespace detail
1013 } // namespace mlir
1014 
1015 const ConversionConfig &IRRewrite::getConfig() const {
1016  return rewriterImpl.config;
1017 }
1018 
1019 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1020  // Inform the listener about all IR modifications that have already taken
1021  // place: References to the original block have been replaced with the new
1022  // block.
1023  if (auto *listener =
1024  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1025  for (Operation *op : block->getUsers())
1026  listener->notifyOperationModified(op);
1027 
1028  // Process the remapping for each of the original arguments.
1029  for (auto [origArg, info] :
1030  llvm::zip_equal(origBlock->getArguments(), argInfo)) {
1031  // Handle the case of a 1->0 value mapping.
1032  if (!info) {
1033  if (Value newArg =
1034  rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
1035  rewriter.replaceAllUsesWith(origArg, newArg);
1036  continue;
1037  }
1038 
1039  // Otherwise this is a 1->1+ value mapping.
1040  Value castValue = info->castValue;
1041  assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
1042 
1043  // If the argument is still used, replace it with the generated cast.
1044  if (!origArg.use_empty()) {
1045  rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
1046  castValue, origArg.getType()));
1047  }
1048  }
1049 }
1050 
1051 void BlockTypeConversionRewrite::rollback() {
1052  block->replaceAllUsesWith(origBlock);
1053 }
1054 
/// For each original block argument that is not mapped to a value of its own
/// type but still has a live user (as reported by `findLiveUser`), materialize
/// a source conversion back to the original type and record it in the
/// mapping. Emits an error and returns failure if no such value can be built
/// (e.g. because there is no `converter`).
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
    function_ref<Operation *(Value)> findLiveUser) {
  // Process the remapping for each of the original arguments.
  for (auto it : llvm::enumerate(origBlock->getArguments())) {
    BlockArgument origArg = it.value();
    // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
    OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
    builder.setInsertionPointToStart(block);

    // If the type of this argument changed and the argument is still live, we
    // need to materialize a conversion.
    if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
      continue;
    Operation *liveUser = findLiveUser(origArg);
    if (!liveUser)
      continue;

    // If the lookup yields `origArg` itself, the argument is treated as
    // dropped: the materialization gets no inputs and is built at the start of
    // `block`. Otherwise it is built right after the replacement value.
    Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
    bool isDroppedArg = replacementValue == origArg;
    if (!isDroppedArg)
      builder.setInsertionPointAfterValue(replacementValue);
    Value newArg;
    if (converter) {
      newArg = converter->materializeSourceConversion(
          builder, origArg.getLoc(), origArg.getType(),
          isDroppedArg ? ValueRange() : ValueRange(replacementValue));
      assert((!newArg || newArg.getType() == origArg.getType()) &&
             "materialization hook did not provide a value of the expected "
             "type");
    }
    if (!newArg) {
      emitError(origArg.getLoc())
          << "failed to materialize conversion for block argument #"
          << it.index() << " that remained live after conversion, type was "
          << origArg.getType();
      if (!isDroppedArg)
        diag << ", with target type " << replacementValue.getType();
      diag.attachNote(liveUser->getLoc())
          << "see existing live user here: " << *liveUser;
      return failure();
    }
    // Later lookups of `origArg` at its original type now find the new value.
    rewriterImpl.mapping.map(origArg, newArg);
  }
  return success();
}
1101 
1102 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1103  Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
1104  if (!repl)
1105  return;
1106 
1107  if (isa<BlockArgument>(repl)) {
1108  rewriter.replaceAllUsesWith(arg, repl);
1109  return;
1110  }
1111 
1112  // If the replacement value is an operation, we check to make sure that we
1113  // don't replace uses that are within the parent operation of the
1114  // replacement value.
1115  Operation *replOp = cast<OpResult>(repl).getOwner();
1116  Block *replBlock = replOp->getBlock();
1117  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1118  Operation *user = operand.getOwner();
1119  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1120  });
1121 }
1122 
1123 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
1124 
/// Commit an op replacement: notify the listener, replace result uses with
/// their mapped values, drop the op from the unlegalized-op set, report the
/// erasure, and unlink the op. Destruction is deferred to cleanup().
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
  auto *listener =
      dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());

  // Compute replacement values. A result without a mapping of its own type
  // gets a null entry.
  SmallVector<Value> replacements =
      llvm::map_to_vector(op->getResults(), [&](OpResult result) {
        return rewriterImpl.mapping.lookupOrNull(result, result.getType());
      });

  // Notify the listener that the operation is about to be replaced.
  if (listener)
    listener->notifyOperationReplaced(op, replacements);

  // Replace all uses with the new values. Results with a null replacement keep
  // their uses.
  for (auto [result, newValue] :
       llvm::zip_equal(op->getResults(), replacements))
    if (newValue)
      rewriter.replaceAllUsesWith(result, newValue);

  // The original op will be erased, so remove it from the set of unlegalized
  // ops.
  if (getConfig().unlegalizedOps)
    getConfig().unlegalizedOps->erase(op);

  // Notify the listener that the operation (and its nested operations) was
  // erased.
  if (listener) {
        [&](Operation *op) { listener->notifyOperationErased(op); });
  }

  // Do not erase the operation yet. It may still be referenced in `mapping`.
  // Just unlink it for now and erase it during cleanup.
  op->getBlock()->getOperations().remove(op);
}
1161 
1162 void ReplaceOperationRewrite::rollback() {
1163  for (auto result : op->getResults())
1164  rewriterImpl.mapping.erase(result);
1165 }
1166 
1167 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1168  rewriter.eraseOp(op);
1169 }
1170 
1171 void CreateOperationRewrite::rollback() {
1172  for (Region &region : op->getRegions()) {
1173  while (!region.getBlocks().empty())
1174  region.getBlocks().remove(region.getBlocks().begin());
1175  }
1176  op->dropAllUses();
1177  op->erase();
1178 }
1179 
1180 void UnresolvedMaterializationRewrite::rollback() {
1181  if (getMaterializationKind() == MaterializationKind::Target) {
1182  for (Value input : op->getOperands())
1183  rewriterImpl.mapping.erase(input);
1184  }
1185  op->erase();
1186 }
1187 
1188 void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
1189  rewriter.eraseOp(op);
1190 }
1191 
  // Commit all rewrites, in the order they were recorded. The rewriter carries
  // the user-provided listener (if any) from the config.
  IRRewriter rewriter(context, config.listener);
  for (auto &rewrite : rewrites)
    rewrite->commit(rewriter);

  // Clean up all rewrites. This runs only after every rewrite was committed
  // and uses a rewriter that ignores repeated erasure of the same op/block.
  SingleEraseRewriter eraseRewriter(context);
  for (auto &rewrite : rewrites)
    rewrite->cleanup(eraseRewriter);
}
1203 
1204 //===----------------------------------------------------------------------===//
1205 // State Management
1206 
  // Snapshot only the sizes of the rewrite log and of the ignored/replaced op
  // lists; resetState() truncates back to these sizes.
  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
}
1210 
  // Undo any rewrites recorded after the snapshot.
  undoRewrites(state.numRewrites);

  // Pop all of the recorded ignored operations that are no longer valid.
  while (ignoredOps.size() != state.numIgnoredOperations)
    ignoredOps.pop_back();

  // Likewise forget ops that were marked replaced/erased after the snapshot.
  while (replacedOps.size() != state.numReplacedOps)
    replacedOps.pop_back();
}
1222 
1223 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1224  for (auto &rewrite :
1225  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1226  rewrite->rollback();
1227  rewrites.resize(numRewritesToKeep);
1228 }
1229 
    StringRef valueDiagTag, std::optional<Location> inputLoc,
    PatternRewriter &rewriter, ValueRange values,
    SmallVectorImpl<Value> &remapped) {
  // NOTE(review): `rewriter` appears unused in this function; confirm whether
  // it is still needed. On failure, `remapped` may already hold the values
  // remapped before the failing one.
  remapped.reserve(llvm::size(values));

  // Reused across iterations to avoid reallocating the type buffer.
  SmallVector<Type, 1> legalTypes;
  for (const auto &it : llvm::enumerate(values)) {
    Value operand = it.value();
    Type origType = operand.getType();

    // If a converter was provided, get the desired legal types for this
    // operand.
    Type desiredType;
    if (currentTypeConverter) {
      // If there is no legal conversion, fail to match this pattern.
      legalTypes.clear();
      if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
        Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
        notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
          diag << "unable to convert type for " << valueDiagTag << " #"
               << it.index() << ", type was " << origType;
        });
        return failure();
      }
      // TODO: There currently isn't any mechanism to do 1->N type conversion
      // via the PatternRewriter replacement API, so for now we just ignore it.
      if (legalTypes.size() == 1)
        desiredType = legalTypes.front();
    } else {
      // TODO: What we should do here is just set `desiredType` to `origType`
      // and then handle the necessary type conversions after the conversion
      // process has finished. Unfortunately a lot of patterns currently rely on
      // receiving the new operands even if the types change, so we keep the
      // original behavior here for now until all of the patterns relying on
      // this get updated.
    }
    // `desiredType` is still null when there is no converter or the conversion
    // is not 1->1.
    Value newOperand = mapping.lookupOrDefault(operand, desiredType);

    // Handle the case where the conversion was 1->1 and the new operand type
    // isn't legal.
    Type newOperandType = newOperand.getType();
    if (currentTypeConverter && desiredType && newOperandType != desiredType) {
      Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
          operandLoc, newOperand, desiredType, currentTypeConverter);
      // Record the cast so later lookups reuse it instead of building another.
      mapping.map(mapping.lookupOrDefault(newOperand), castValue);
      newOperand = castValue;
    }
    remapped.push_back(newOperand);
  }
  return success();
}
1283 
  // Check to see if this operation is ignored or was replaced. Replaced/erased
  // ops are tracked separately from ignored ops, so both sets are consulted.
  return replacedOps.count(op) || ignoredOps.count(op);
}
1288 
  // Check to see if this operation was replaced (or erased, which is recorded
  // as a replacement with null values).
  return replacedOps.count(op);
}
1293 
1294 //===----------------------------------------------------------------------===//
1295 // Type Conversion
1296 
    ConversionPatternRewriter &rewriter, Block *block,
    const TypeConverter *converter,
    TypeConverter::SignatureConversion *conversion) {
  // An explicit conversion takes precedence over asking the converter.
  if (conversion)
    return applySignatureConversion(rewriter, block, converter, *conversion);

  // If a converter wasn't provided, and the block wasn't already converted,
  // there is nothing we can do.
  if (!converter)
    return failure();

  // Try to convert the signature for the block with the provided converter.
  // Note: this `conversion` shadows the (null) parameter of the same name.
  if (auto conversion = converter->convertBlockSignature(block))
    return applySignatureConversion(rewriter, block, converter, *conversion);
  return failure();
}
1314 
    ConversionPatternRewriter &rewriter, Region *region,
    const TypeConverter *converter) {
  // With a non-null conversion, convertBlockSignature forwards directly to the
  // block overload of applySignatureConversion, which returns a Block* rather
  // than failure, so dereferencing the result is safe.
  if (!region->empty())
    return *convertBlockSignature(rewriter, &region->front(), converter,
                                  &conversion);
  // An empty region has no entry block to convert.
  return nullptr;
}
1324 
    ConversionPatternRewriter &rewriter, Region *region,
    const TypeConverter &converter,
    TypeConverter::SignatureConversion *entryConversion) {
  // Remember which converter governs the block arguments of this region.
  regionToConverter[region] = &converter;
  if (region->empty())
    return nullptr;

  // Convert the non-entry blocks first, then the entry block (which may use
  // the explicitly provided `entryConversion`).
  if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
    return failure();

      rewriter, &region->front(), &converter, entryConversion);
  return newEntry;
}
1340 
    ConversionPatternRewriter &rewriter, Region *region,
    const TypeConverter &converter,
  // Remember which converter governs the block arguments of this region.
  regionToConverter[region] = &converter;
  if (region->empty())
    return success();

  // Convert the arguments of each block within the region.
  int blockIdx = 0;
  assert((blockConversions.empty() ||
          blockConversions.size() == region->getBlocks().size() - 1) &&
         "expected either to provide no SignatureConversions at all or to "
         "provide a SignatureConversion for each non-entry block");

  // Converting a block inserts a replacement block right after it and unlinks
  // the original, so the iterator must advance before each block is converted.
  for (Block &block :
       llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
    TypeConverter::SignatureConversion *blockConversion =
        blockConversions.empty()
            ? nullptr
            : const_cast<TypeConverter::SignatureConversion *>(
                  &blockConversions[blockIdx++]);

    if (failed(convertBlockSignature(rewriter, &block, &converter,
                                     blockConversion)))
      return failure();
  }
  return success();
}
1370 
1372  ConversionPatternRewriter &rewriter, Block *block,
1373  const TypeConverter *converter,
1374  TypeConverter::SignatureConversion &signatureConversion) {
1375  OpBuilder::InsertionGuard g(rewriter);
1376 
1377  // If no arguments are being changed or added, there is nothing to do.
1378  unsigned origArgCount = block->getNumArguments();
1379  auto convertedTypes = signatureConversion.getConvertedTypes();
1380  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1381  return block;
1382 
1383  // Compute the locations of all block arguments in the new block.
1384  SmallVector<Location> newLocs(convertedTypes.size(),
1385  rewriter.getUnknownLoc());
1386  for (unsigned i = 0; i < origArgCount; ++i) {
1387  auto inputMap = signatureConversion.getInputMapping(i);
1388  if (!inputMap || inputMap->replacementValue)
1389  continue;
1390  Location origLoc = block->getArgument(i).getLoc();
1391  for (unsigned j = 0; j < inputMap->size; ++j)
1392  newLocs[inputMap->inputNo + j] = origLoc;
1393  }
1394 
1395  // Insert a new block with the converted block argument types and move all ops
1396  // from the old block to the new block.
1397  Block *newBlock =
1398  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1399  convertedTypes, newLocs);
1400 
1401  // If a listener is attached to the dialect conversion, ops cannot be moved
1402  // to the destination block in bulk ("fast path"). This is because at the time
1403  // the notifications are sent, it is unknown which ops were moved. Instead,
1404  // ops should be moved one-by-one ("slow path"), so that a separate
1405  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1406  // a bit more efficient, so we try to do that when possible.
1407  bool fastPath = !config.listener;
1408  if (fastPath) {
1409  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1410  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1411  } else {
1412  while (!block->empty())
1413  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1414  }
1415 
1416  // Replace all uses of the old block with the new block.
1417  block->replaceAllUsesWith(newBlock);
1418 
1419  // Remap each of the original arguments as determined by the signature
1420  // conversion.
1422  argInfo.resize(origArgCount);
1423 
1424  for (unsigned i = 0; i != origArgCount; ++i) {
1425  auto inputMap = signatureConversion.getInputMapping(i);
1426  if (!inputMap)
1427  continue;
1428  BlockArgument origArg = block->getArgument(i);
1429 
1430  // If inputMap->replacementValue is not nullptr, then the argument is
1431  // dropped and a replacement value is provided to be the remappedValue.
1432  if (inputMap->replacementValue) {
1433  assert(inputMap->size == 0 &&
1434  "invalid to provide a replacement value when the argument isn't "
1435  "dropped");
1436  mapping.map(origArg, inputMap->replacementValue);
1437  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1438  continue;
1439  }
1440 
1441  // Otherwise, this is a 1->1+ mapping.
1442  auto replArgs =
1443  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1444  Value newArg;
1445 
1446  // If this is a 1->1 mapping and the types of new and replacement arguments
1447  // match (i.e. it's an identity map), then the argument is mapped to its
1448  // original type.
1449  // FIXME: We simply pass through the replacement argument if there wasn't a
1450  // converter, which isn't great as it allows implicit type conversions to
1451  // appear. We should properly restructure this code to handle cases where a
1452  // converter isn't provided and also to properly handle the case where an
1453  // argument materialization is actually a temporary source materialization
1454  // (e.g. in the case of 1->N).
1455  if (replArgs.size() == 1 &&
1456  (!converter || replArgs[0].getType() == origArg.getType())) {
1457  newArg = replArgs.front();
1458  } else {
1459  Type origOutputType = origArg.getType();
1460 
1461  // Legalize the argument output type.
1462  Type outputType = origOutputType;
1463  if (Type legalOutputType = converter->convertType(outputType))
1464  outputType = legalOutputType;
1465 
1467  newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
1468  converter);
1469  }
1470 
1471  mapping.map(origArg, newArg);
1472  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1473  argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
1474  }
1475 
1476  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1477  converter);
1478 
1479  // Erase the old block. (It is just unlinked for now and will be erased during
1480  // cleanup.)
1481  rewriter.eraseBlock(block);
1482 
1483  return newBlock;
1484 }
1485 
1486 //===----------------------------------------------------------------------===//
1487 // Materializations
1488 //===----------------------------------------------------------------------===//
1489 
/// Build an unresolved materialization operation given an output type and set
/// of input operands. The `UnrealizedConversionCastOp` is inserted at
/// `insertPt` in `insertBlock` and recorded as an undoable rewrite. If `inputs`
/// is a single value that already has `outputType`, it is returned directly.
    MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
    Location loc, ValueRange inputs, Type outputType, Type origOutputType,
    const TypeConverter *converter) {
  // Avoid materializing an unnecessary cast.
  if (inputs.size() == 1 && inputs.front().getType() == outputType)
    return inputs.front();

  // Create an unresolved materialization. We use a new OpBuilder to avoid
  // tracking the materialization like we do for other operations.
  OpBuilder builder(insertBlock, insertPt);
  auto convertOp =
      builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
                                                  origOutputType);
  return convertOp.getResult(0);
}
    Block *block, Location loc, ValueRange inputs, Type origOutputType,
    Type outputType, const TypeConverter *converter) {
  // Insert at the start of `block`. Note the parameter order: this helper
  // takes `origOutputType` before `outputType`, whereas the generic builder
  // expects `outputType` first.
                                        block->begin(), loc, inputs, outputType,
                                        origOutputType, converter);
}
    Location loc, Value input, Type outputType,
    const TypeConverter *converter) {
  // Place the cast right after the op that defines `input`; for a block
  // argument, place it at the start of the owning block.
  Block *insertBlock = input.getParentBlock();
  Block::iterator insertPt = insertBlock->begin();
  if (OpResult inputRes = dyn_cast<OpResult>(input))
    insertPt = ++inputRes.getOwner()->getIterator();

  // For target materializations the "original" type is the output type.
  return buildUnresolvedMaterialization(MaterializationKind::Target,
                                        insertBlock, insertPt, loc, input,
                                        outputType, outputType, converter);
}
1528 
1529 //===----------------------------------------------------------------------===//
1530 // Rewriter Notification Hooks
1531 
    Operation *op, OpBuilder::InsertPoint previous) {
  LLVM_DEBUG({
    logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
                       << ")\n";
  });
  assert(!wasOpReplaced(op->getParentOp()) &&
         "attempting to insert into a block within a replaced/erased op");

  if (!previous.isSet()) {
    // This is a newly created op.
    appendRewrite<CreateOperationRewrite>(op);
    return;
  }
  // The op was moved: record its previous block and the op it preceded
  // (nullptr if it was at the end of that block).
  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
                          ? nullptr
                          : &*previous.getPoint();
  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
}
1551 
    ValueRange newValues) {
  assert(newValues.size() == op->getNumResults());
  // NOTE(review): this checks `ignoredOps`, but replaced ops are recorded in
  // `replacedOps` (see the walk at the end), so replacing the same op twice is
  // not caught here -- confirm whether `wasOpReplaced(op)` was intended.
  assert(!ignoredOps.contains(op) && "operation was already replaced");

  // Track if any of the results changed, e.g. erased and replaced with null.
  bool resultChanged = false;

  // Create mappings for each of the new result values.
  for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
    if (!newValue) {
      resultChanged = true;
      continue;
    }
    // Remap, and check for any result type changes.
    mapping.map(result, newValue);
    resultChanged |= (newValue.getType() != result.getType());
  }

  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
                                         resultChanged);

  // Mark this operation and all nested ops as replaced.
  op->walk([&](Operation *op) { replacedOps.insert(op); });
}
1577 
  // Record the block's current position (parent region and successor block)
  // together with the block in the erase rewrite.
  Region *region = block->getParent();
  Block *origNextBlock = block->getNextNode();
  appendRewrite<EraseBlockRewrite>(block, region, origNextBlock);
}
1583 
1585  Block *block, Region *previous, Region::iterator previousIt) {
1586  assert(!wasOpReplaced(block->getParentOp()) &&
1587  "attempting to insert into a region within a replaced/erased op");
1588  LLVM_DEBUG(
1589  {
1590  Operation *parent = block->getParentOp();
1591  if (parent) {
1592  logger.startLine() << "** Insert Block into : '" << parent->getName()
1593  << "'(" << parent << ")\n";
1594  } else {
1595  logger.startLine()
1596  << "** Insert Block into detached Region (nullptr parent op)'";
1597  }
1598  });
1599 
1600  if (!previous) {
1601  // This is a newly created block.
1602  appendRewrite<CreateBlockRewrite>(block);
1603  return;
1604  }
1605  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1606  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1607 }
1608 
    Block *block, Block *srcBlock, Block::iterator before) {
  // Record the inlining of `srcBlock` into `block` before `before`.
  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
}
1613 
    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
  // The whole body is wrapped in LLVM_DEBUG: in builds without debug output,
  // neither the reason nor `config.notifyCallback` is evaluated here.
  LLVM_DEBUG({
    reasonCallback(diag);
    logger.startLine() << "** Failure : " << diag.str() << "\n";
    if (config.notifyCallback)
  });
}
1624 
1625 //===----------------------------------------------------------------------===//
1626 // ConversionPatternRewriter
1627 //===----------------------------------------------------------------------===//
1628 
1629 ConversionPatternRewriter::ConversionPatternRewriter(
1630  MLIRContext *ctx, const ConversionConfig &config)
1631  : PatternRewriter(ctx),
1632  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1633  setListener(impl.get());
1634 }
1635 
1637 
  assert(op && newOp && "expected non-null op");
  // Replacing with an op is replacing with all of that op's results.
  replaceOp(op, newOp->getResults());
}
1642 
  assert(op->getNumResults() == newValues.size() &&
         "incorrect # of replacement values");
  LLVM_DEBUG({
    impl->logger.startLine()
        << "** Replace : '" << op->getName() << "'(" << op << ")\n";
  });
  // Only records the replacement; uses are rewritten when the conversion
  // commits.
  impl->notifyOpReplaced(op, newValues);
}
1652 
  LLVM_DEBUG({
    impl->logger.startLine()
        << "** Erase : '" << op->getName() << "'(" << op << ")\n";
  });
  // Erasure is recorded as a replacement of every result with a null value;
  // the op itself is only removed once the conversion commits.
  SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
  impl->notifyOpReplaced(op, nullRepls);
}
1661 
  assert(!impl->wasOpReplaced(block->getParentOp()) &&
         "attempting to erase a block within a replaced/erased op");

  // Mark all ops for erasure.
  for (Operation &op : *block)
    eraseOp(&op);

  // Unlink the block from its parent region. The block is kept in the rewrite
  // object and will be actually destroyed when rewrites are applied. This
  // allows us to keep the operations in the block live and undo the removal by
  // re-inserting the block.
  impl->notifyBlockIsBeingErased(block);
  block->getParent()->getBlocks().remove(block);
}
1677 
    Region *region, TypeConverter::SignatureConversion &conversion,
    const TypeConverter *converter) {
  assert(!impl->wasOpReplaced(region->getParentOp()) &&
         "attempting to apply a signature conversion to a block within a "
         "replaced/erased op");
  // Converts the entry block of `region`; returns nullptr for an empty region.
  return impl->applySignatureConversion(*this, region, conversion, converter);
}
1686 
    Region *region, const TypeConverter &converter,
    TypeConverter::SignatureConversion *entryConversion) {
  assert(!impl->wasOpReplaced(region->getParentOp()) &&
         "attempting to apply a signature conversion to a block within a "
         "replaced/erased op");
  // Converts all blocks of `region`, entry block included.
  return impl->convertRegionTypes(*this, region, converter, entryConversion);
}
1695 
    Region *region, const TypeConverter &converter,
  assert(!impl->wasOpReplaced(region->getParentOp()) &&
         "attempting to apply a signature conversion to a block within a "
         "replaced/erased op");
  // Converts every block of `region` except the entry block.
  return impl->convertNonEntryRegionTypes(*this, region, converter,
                                          blockConversions);
}
1705 
1707  Value to) {
1708  LLVM_DEBUG({
1709  Operation *parentOp = from.getOwner()->getParentOp();
1710  impl->logger.startLine() << "** Replace Argument : '" << from
1711  << "'(in region of '" << parentOp->getName()
1712  << "'(" << from.getOwner()->getParentOp() << ")\n";
1713  });
1714  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
1715  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1716 }
1717 
1719  SmallVector<Value> remappedValues;
1720  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1721  remappedValues)))
1722  return nullptr;
1723  return remappedValues.front();
1724 }
1725 
1728  SmallVectorImpl<Value> &results) {
1729  if (keys.empty())
1730  return success();
1731  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1732  results);
1733 }
1734 
1736  Block::iterator before,
1737  ValueRange argValues) {
1738 #ifndef NDEBUG
1739  assert(argValues.size() == source->getNumArguments() &&
1740  "incorrect # of argument replacement values");
1741  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1742  "attempting to inline a block from a replaced/erased op");
1743  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1744  "attempting to inline a block into a replaced/erased op");
1745  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1746  // The source block will be deleted, so it should not have any users (i.e.,
1747  // there should be no predecessors).
1748  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1749  "expected 'source' to have no predecessors");
1750 #endif // NDEBUG
1751 
1752  // If a listener is attached to the dialect conversion, ops cannot be moved
1753  // to the destination block in bulk ("fast path"). This is because at the time
1754  // the notifications are sent, it is unknown which ops were moved. Instead,
1755  // ops should be moved one-by-one ("slow path"), so that a separate
1756  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1757  // a bit more efficient, so we try to do that when possible.
1758  bool fastPath = !impl->config.listener;
1759 
1760  if (fastPath)
1761  impl->notifyBlockBeingInlined(dest, source, before);
1762 
1763  // Replace all uses of block arguments.
1764  for (auto it : llvm::zip(source->getArguments(), argValues))
1765  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1766 
1767  if (fastPath) {
1768  // Move all ops at once.
1769  dest->getOperations().splice(before, source->getOperations());
1770  } else {
1771  // Move op by op.
1772  while (!source->empty())
1773  moveOpBefore(&source->front(), dest, before);
1774  }
1775 
1776  // Erase the source block.
1777  eraseBlock(source);
1778 }
1779 
1781  assert(!impl->wasOpReplaced(op) &&
1782  "attempting to modify a replaced/erased op");
1783 #ifndef NDEBUG
1784  impl->pendingRootUpdates.insert(op);
1785 #endif
1786  impl->appendRewrite<ModifyOperationRewrite>(op);
1787 }
1788 
1790  assert(!impl->wasOpReplaced(op) &&
1791  "attempting to modify a replaced/erased op");
1793  // There is nothing to do here, we only need to track the operation at the
1794  // start of the update.
1795 #ifndef NDEBUG
1796  assert(impl->pendingRootUpdates.erase(op) &&
1797  "operation did not have a pending in-place update");
1798 #endif
1799 }
1800 
1802 #ifndef NDEBUG
1803  assert(impl->pendingRootUpdates.erase(op) &&
1804  "operation did not have a pending in-place update");
1805 #endif
1806  // Erase the last update for this operation.
1807  auto it = llvm::find_if(
1808  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1809  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1810  return modifyRewrite && modifyRewrite->getOperation() == op;
1811  });
1812  assert(it != impl->rewrites.rend() && "no root update started on op");
1813  (*it)->rollback();
1814  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1815  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1816 }
1817 
1819  return *impl;
1820 }
1821 
1822 //===----------------------------------------------------------------------===//
1823 // ConversionPattern
1824 //===----------------------------------------------------------------------===//
1825 
1828  PatternRewriter &rewriter) const {
1829  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1830  auto &rewriterImpl = dialectRewriter.getImpl();
1831 
1832  // Track the current conversion pattern type converter in the rewriter.
1833  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1834  getTypeConverter());
1835 
1836  // Remap the operands of the operation.
1837  SmallVector<Value, 4> operands;
1838  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1839  op->getOperands(), operands))) {
1840  return failure();
1841  }
1842  return matchAndRewrite(op, operands, dialectRewriter);
1843 }
1844 
1845 //===----------------------------------------------------------------------===//
1846 // OperationLegalizer
1847 //===----------------------------------------------------------------------===//
1848 
1849 namespace {
1850 /// A set of rewrite patterns that can be used to legalize a given operation.
1851 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1852 
1853 /// This class defines a recursive operation legalizer.
1854 class OperationLegalizer {
1855 public:
1856  using LegalizationAction = ConversionTarget::LegalizationAction;
1857 
1858  OperationLegalizer(const ConversionTarget &targetInfo,
1859  const FrozenRewritePatternSet &patterns,
1860  const ConversionConfig &config);
1861 
1862  /// Returns true if the given operation is known to be illegal on the target.
1863  bool isIllegal(Operation *op) const;
1864 
1865  /// Attempt to legalize the given operation. Returns success if the operation
1866  /// was legalized, failure otherwise.
1867  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1868 
1869  /// Returns the conversion target in use by the legalizer.
1870  const ConversionTarget &getTarget() { return target; }
1871 
1872 private:
1873  /// Attempt to legalize the given operation by folding it.
1874  LogicalResult legalizeWithFold(Operation *op,
1875  ConversionPatternRewriter &rewriter);
1876 
1877  /// Attempt to legalize the given operation by applying a pattern. Returns
1878  /// success if the operation was legalized, failure otherwise.
1879  LogicalResult legalizeWithPattern(Operation *op,
1880  ConversionPatternRewriter &rewriter);
1881 
1882  /// Return true if the given pattern may be applied to the given operation,
1883  /// false otherwise.
1884  bool canApplyPattern(Operation *op, const Pattern &pattern,
1885  ConversionPatternRewriter &rewriter);
1886 
1887  /// Legalize the resultant IR after successfully applying the given pattern.
1888  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1889  ConversionPatternRewriter &rewriter,
1890  RewriterState &curState);
1891 
1892  /// Legalizes the actions registered during the execution of a pattern.
1894  legalizePatternBlockRewrites(Operation *op,
1895  ConversionPatternRewriter &rewriter,
1897  RewriterState &state, RewriterState &newState);
1898  LogicalResult legalizePatternCreatedOperations(
1900  RewriterState &state, RewriterState &newState);
1901  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1903  RewriterState &state,
1904  RewriterState &newState);
1905 
1906  //===--------------------------------------------------------------------===//
1907  // Cost Model
1908  //===--------------------------------------------------------------------===//
1909 
1910  /// Build an optimistic legalization graph given the provided patterns. This
1911  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1912  /// patterns for operations that are not directly legal, but may be
1913  /// transitively legal for the current target given the provided patterns.
1914  void buildLegalizationGraph(
1915  LegalizationPatterns &anyOpLegalizerPatterns,
1917 
1918  /// Compute the benefit of each node within the computed legalization graph.
1919  /// This orders the patterns within 'legalizerPatterns' based upon two
1920  /// criteria:
1921  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1922  /// represent the more direct mapping to the target.
1923  /// 2) When comparing patterns with the same legalization depth, prefer the
1924  /// pattern with the highest PatternBenefit. This allows for users to
1925  /// prefer specific legalizations over others.
1926  void computeLegalizationGraphBenefit(
1927  LegalizationPatterns &anyOpLegalizerPatterns,
1929 
1930  /// Compute the legalization depth when legalizing an operation of the given
1931  /// type.
1932  unsigned computeOpLegalizationDepth(
1933  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1935 
1936  /// Apply the conversion cost model to the given set of patterns, and return
1937  /// the smallest legalization depth of any of the patterns. See
1938  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1939  unsigned applyCostModelToPatterns(
1940  LegalizationPatterns &patterns,
1941  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1943 
1944  /// The current set of patterns that have been applied.
1945  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1946 
1947  /// The legalization information provided by the target.
1948  const ConversionTarget &target;
1949 
1950  /// The pattern applicator to use for conversions.
1951  PatternApplicator applicator;
1952 
1953  /// Dialect conversion configuration.
1954  const ConversionConfig &config;
1955 };
1956 } // namespace
1957 
1958 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1959  const FrozenRewritePatternSet &patterns,
1960  const ConversionConfig &config)
1961  : target(targetInfo), applicator(patterns), config(config) {
1962  // The set of patterns that can be applied to illegal operations to transform
1963  // them into legal ones.
1965  LegalizationPatterns anyOpLegalizerPatterns;
1966 
1967  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1968  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1969 }
1970 
1971 bool OperationLegalizer::isIllegal(Operation *op) const {
1972  return target.isIllegal(op);
1973 }
1974 
1976 OperationLegalizer::legalize(Operation *op,
1977  ConversionPatternRewriter &rewriter) {
1978 #ifndef NDEBUG
1979  const char *logLineComment =
1980  "//===-------------------------------------------===//\n";
1981 
1982  auto &logger = rewriter.getImpl().logger;
1983 #endif
1984  LLVM_DEBUG({
1985  logger.getOStream() << "\n";
1986  logger.startLine() << logLineComment;
1987  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1988  << op << ") {\n";
1989  logger.indent();
1990 
1991  // If the operation has no regions, just print it here.
1992  if (op->getNumRegions() == 0) {
1993  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1994  logger.getOStream() << "\n\n";
1995  }
1996  });
1997 
1998  // Check if this operation is legal on the target.
1999  if (auto legalityInfo = target.isLegal(op)) {
2000  LLVM_DEBUG({
2001  logSuccess(
2002  logger, "operation marked legal by the target{0}",
2003  legalityInfo->isRecursivelyLegal
2004  ? "; NOTE: operation is recursively legal; skipping internals"
2005  : "");
2006  logger.startLine() << logLineComment;
2007  });
2008 
2009  // If this operation is recursively legal, mark its children as ignored so
2010  // that we don't consider them for legalization.
2011  if (legalityInfo->isRecursivelyLegal) {
2012  op->walk([&](Operation *nested) {
2013  if (op != nested)
2014  rewriter.getImpl().ignoredOps.insert(nested);
2015  });
2016  }
2017 
2018  return success();
2019  }
2020 
2021  // Check to see if the operation is ignored and doesn't need to be converted.
2022  if (rewriter.getImpl().isOpIgnored(op)) {
2023  LLVM_DEBUG({
2024  logSuccess(logger, "operation marked 'ignored' during conversion");
2025  logger.startLine() << logLineComment;
2026  });
2027  return success();
2028  }
2029 
2030  // If the operation isn't legal, try to fold it in-place.
2031  // TODO: Should we always try to do this, even if the op is
2032  // already legal?
2033  if (succeeded(legalizeWithFold(op, rewriter))) {
2034  LLVM_DEBUG({
2035  logSuccess(logger, "operation was folded");
2036  logger.startLine() << logLineComment;
2037  });
2038  return success();
2039  }
2040 
2041  // Otherwise, we need to apply a legalization pattern to this operation.
2042  if (succeeded(legalizeWithPattern(op, rewriter))) {
2043  LLVM_DEBUG({
2044  logSuccess(logger, "");
2045  logger.startLine() << logLineComment;
2046  });
2047  return success();
2048  }
2049 
2050  LLVM_DEBUG({
2051  logFailure(logger, "no matched legalization pattern");
2052  logger.startLine() << logLineComment;
2053  });
2054  return failure();
2055 }
2056 
2058 OperationLegalizer::legalizeWithFold(Operation *op,
2059  ConversionPatternRewriter &rewriter) {
2060  auto &rewriterImpl = rewriter.getImpl();
2061  RewriterState curState = rewriterImpl.getCurrentState();
2062 
2063  LLVM_DEBUG({
2064  rewriterImpl.logger.startLine() << "* Fold {\n";
2065  rewriterImpl.logger.indent();
2066  });
2067 
2068  // Try to fold the operation.
2069  SmallVector<Value, 2> replacementValues;
2070  rewriter.setInsertionPoint(op);
2071  if (failed(rewriter.tryFold(op, replacementValues))) {
2072  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2073  return failure();
2074  }
2075 
2076  // Insert a replacement for 'op' with the folded replacement values.
2077  rewriter.replaceOp(op, replacementValues);
2078 
2079  // Recursively legalize any new constant operations.
2080  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2081  i != e; ++i) {
2082  auto *createOp =
2083  dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
2084  if (!createOp)
2085  continue;
2086  if (failed(legalize(createOp->getOperation(), rewriter))) {
2087  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2088  "failed to legalize generated constant '{0}'",
2089  createOp->getOperation()->getName()));
2090  rewriterImpl.resetState(curState);
2091  return failure();
2092  }
2093  }
2094 
2095  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2096  return success();
2097 }
2098 
2100 OperationLegalizer::legalizeWithPattern(Operation *op,
2101  ConversionPatternRewriter &rewriter) {
2102  auto &rewriterImpl = rewriter.getImpl();
2103 
2104  // Functor that returns if the given pattern may be applied.
2105  auto canApply = [&](const Pattern &pattern) {
2106  bool canApply = canApplyPattern(op, pattern, rewriter);
2107  if (canApply && config.listener)
2108  config.listener->notifyPatternBegin(pattern, op);
2109  return canApply;
2110  };
2111 
2112  // Functor that cleans up the rewriter state after a pattern failed to match.
2113  RewriterState curState = rewriterImpl.getCurrentState();
2114  auto onFailure = [&](const Pattern &pattern) {
2115  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2116  LLVM_DEBUG({
2117  logFailure(rewriterImpl.logger, "pattern failed to match");
2118  if (rewriterImpl.config.notifyCallback) {
2120  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2121  << "\" on op:\n"
2122  << *op;
2123  rewriterImpl.config.notifyCallback(diag);
2124  }
2125  });
2126  if (config.listener)
2127  config.listener->notifyPatternEnd(pattern, failure());
2128  rewriterImpl.resetState(curState);
2129  appliedPatterns.erase(&pattern);
2130  };
2131 
2132  // Functor that performs additional legalization when a pattern is
2133  // successfully applied.
2134  auto onSuccess = [&](const Pattern &pattern) {
2135  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2136  auto result = legalizePatternResult(op, pattern, rewriter, curState);
2137  appliedPatterns.erase(&pattern);
2138  if (failed(result))
2139  rewriterImpl.resetState(curState);
2140  if (config.listener)
2141  config.listener->notifyPatternEnd(pattern, result);
2142  return result;
2143  };
2144 
2145  // Try to match and rewrite a pattern on this operation.
2146  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2147  onSuccess);
2148 }
2149 
2150 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2151  ConversionPatternRewriter &rewriter) {
2152  LLVM_DEBUG({
2153  auto &os = rewriter.getImpl().logger;
2154  os.getOStream() << "\n";
2155  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2156  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2157  os.getOStream() << ")' {\n";
2158  os.indent();
2159  });
2160 
2161  // Ensure that we don't cycle by not allowing the same pattern to be
2162  // applied twice in the same recursion stack if it is not known to be safe.
2163  if (!pattern.hasBoundedRewriteRecursion() &&
2164  !appliedPatterns.insert(&pattern).second) {
2165  LLVM_DEBUG(
2166  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2167  return false;
2168  }
2169  return true;
2170 }
2171 
2173 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2174  ConversionPatternRewriter &rewriter,
2175  RewriterState &curState) {
2176  auto &impl = rewriter.getImpl();
2177 
2178 #ifndef NDEBUG
2179  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2180  // Check that the root was either replaced or updated in place.
2181  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2182  auto replacedRoot = [&] {
2183  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2184  };
2185  auto updatedRootInPlace = [&] {
2186  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2187  };
2188  assert((replacedRoot() || updatedRootInPlace()) &&
2189  "expected pattern to replace the root operation");
2190 #endif // NDEBUG
2191 
2192  // Legalize each of the actions registered during application.
2193  RewriterState newState = impl.getCurrentState();
2194  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2195  newState)) ||
2196  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2197  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2198  newState))) {
2199  return failure();
2200  }
2201 
2202  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2203  return success();
2204 }
2205 
2206 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2207  Operation *op, ConversionPatternRewriter &rewriter,
2208  ConversionPatternRewriterImpl &impl, RewriterState &state,
2209  RewriterState &newState) {
2210  SmallPtrSet<Operation *, 16> operationsToIgnore;
2211 
2212  // If the pattern moved or created any blocks, make sure the types of block
2213  // arguments get legalized.
2214  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2215  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2216  if (!rewrite)
2217  continue;
2218  Block *block = rewrite->getBlock();
2219  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2220  ReplaceBlockArgRewrite>(rewrite))
2221  continue;
2222  // Only check blocks outside of the current operation.
2223  Operation *parentOp = block->getParentOp();
2224  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2225  continue;
2226 
2227  // If the region of the block has a type converter, try to convert the block
2228  // directly.
2229  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2230  if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
2231  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2232  "block"));
2233  return failure();
2234  }
2235  continue;
2236  }
2237 
2238  // Otherwise, check that this operation isn't one generated by this pattern.
2239  // This is because we will attempt to legalize the parent operation, and
2240  // blocks in regions created by this pattern will already be legalized later
2241  // on. If we haven't built the set yet, build it now.
2242  if (operationsToIgnore.empty()) {
2243  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2244  ++i) {
2245  auto *createOp =
2246  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2247  if (!createOp)
2248  continue;
2249  operationsToIgnore.insert(createOp->getOperation());
2250  }
2251  }
2252 
2253  // If this operation should be considered for re-legalization, try it.
2254  if (operationsToIgnore.insert(parentOp).second &&
2255  failed(legalize(parentOp, rewriter))) {
2256  LLVM_DEBUG(logFailure(impl.logger,
2257  "operation '{0}'({1}) became illegal after rewrite",
2258  parentOp->getName(), parentOp));
2259  return failure();
2260  }
2261  }
2262  return success();
2263 }
2264 
2265 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2267  RewriterState &state, RewriterState &newState) {
2268  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2269  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2270  if (!createOp)
2271  continue;
2272  Operation *op = createOp->getOperation();
2273  if (failed(legalize(op, rewriter))) {
2274  LLVM_DEBUG(logFailure(impl.logger,
2275  "failed to legalize generated operation '{0}'({1})",
2276  op->getName(), op));
2277  return failure();
2278  }
2279  }
2280  return success();
2281 }
2282 
2283 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2285  RewriterState &state, RewriterState &newState) {
2286  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2287  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2288  if (!rewrite)
2289  continue;
2290  Operation *op = rewrite->getOperation();
2291  if (failed(legalize(op, rewriter))) {
2292  LLVM_DEBUG(logFailure(
2293  impl.logger, "failed to legalize operation updated in-place '{0}'",
2294  op->getName()));
2295  return failure();
2296  }
2297  }
2298  return success();
2299 }
2300 
2301 //===----------------------------------------------------------------------===//
2302 // Cost Model
2303 
2304 void OperationLegalizer::buildLegalizationGraph(
2305  LegalizationPatterns &anyOpLegalizerPatterns,
2306  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2307  // A mapping between an operation and a set of operations that can be used to
2308  // generate it.
2310  // A mapping between an operation and any currently invalid patterns it has.
2312  // A worklist of patterns to consider for legality.
2313  SetVector<const Pattern *> patternWorklist;
2314 
2315  // Build the mapping from operations to the parent ops that may generate them.
2316  applicator.walkAllPatterns([&](const Pattern &pattern) {
2317  std::optional<OperationName> root = pattern.getRootKind();
2318 
2319  // If the pattern has no specific root, we can't analyze the relationship
2320  // between the root op and generated operations. Given that, add all such
2321  // patterns to the legalization set.
2322  if (!root) {
2323  anyOpLegalizerPatterns.push_back(&pattern);
2324  return;
2325  }
2326 
2327  // Skip operations that are always known to be legal.
2328  if (target.getOpAction(*root) == LegalizationAction::Legal)
2329  return;
2330 
2331  // Add this pattern to the invalid set for the root op and record this root
2332  // as a parent for any generated operations.
2333  invalidPatterns[*root].insert(&pattern);
2334  for (auto op : pattern.getGeneratedOps())
2335  parentOps[op].insert(*root);
2336 
2337  // Add this pattern to the worklist.
2338  patternWorklist.insert(&pattern);
2339  });
2340 
2341  // If there are any patterns that don't have a specific root kind, we can't
2342  // make direct assumptions about what operations will never be legalized.
2343  // Note: Technically we could, but it would require an analysis that may
2344  // recurse into itself. It would be better to perform this kind of filtering
2345  // at a higher level than here anyways.
2346  if (!anyOpLegalizerPatterns.empty()) {
2347  for (const Pattern *pattern : patternWorklist)
2348  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2349  return;
2350  }
2351 
2352  while (!patternWorklist.empty()) {
2353  auto *pattern = patternWorklist.pop_back_val();
2354 
2355  // Check to see if any of the generated operations are invalid.
2356  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2357  std::optional<LegalizationAction> action = target.getOpAction(op);
2358  return !legalizerPatterns.count(op) &&
2359  (!action || action == LegalizationAction::Illegal);
2360  }))
2361  continue;
2362 
2363  // Otherwise, if all of the generated operation are valid, this op is now
2364  // legal so add all of the child patterns to the worklist.
2365  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2366  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2367 
2368  // Add any invalid patterns of the parent operations to see if they have now
2369  // become legal.
2370  for (auto op : parentOps[*pattern->getRootKind()])
2371  patternWorklist.set_union(invalidPatterns[op]);
2372  }
2373 }
2374 
2375 void OperationLegalizer::computeLegalizationGraphBenefit(
2376  LegalizationPatterns &anyOpLegalizerPatterns,
2377  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2378  // The smallest pattern depth, when legalizing an operation.
2379  DenseMap<OperationName, unsigned> minOpPatternDepth;
2380 
2381  // For each operation that is transitively legal, compute a cost for it.
2382  for (auto &opIt : legalizerPatterns)
2383  if (!minOpPatternDepth.count(opIt.first))
2384  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2385  legalizerPatterns);
2386 
2387  // Apply the cost model to the patterns that can match any operation. Those
2388  // with a specific operation type are already resolved when computing the op
2389  // legalization depth.
2390  if (!anyOpLegalizerPatterns.empty())
2391  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2392  legalizerPatterns);
2393 
2394  // Apply a cost model to the pattern applicator. We order patterns first by
2395  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2396  // decreasing benefit.
2397  applicator.applyCostModel([&](const Pattern &pattern) {
2398  ArrayRef<const Pattern *> orderedPatternList;
2399  if (std::optional<OperationName> rootName = pattern.getRootKind())
2400  orderedPatternList = legalizerPatterns[*rootName];
2401  else
2402  orderedPatternList = anyOpLegalizerPatterns;
2403 
2404  // If the pattern is not found, then it was removed and cannot be matched.
2405  auto *it = llvm::find(orderedPatternList, &pattern);
2406  if (it == orderedPatternList.end())
2408 
2409  // Patterns found earlier in the list have higher benefit.
2410  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2411  });
2412 }
2413 
2414 unsigned OperationLegalizer::computeOpLegalizationDepth(
2415  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2416  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2417  // Check for existing depth.
2418  auto depthIt = minOpPatternDepth.find(op);
2419  if (depthIt != minOpPatternDepth.end())
2420  return depthIt->second;
2421 
2422  // If a mapping for this operation does not exist, then this operation
2423  // is always legal. Return 0 as the depth for a directly legal operation.
2424  auto opPatternsIt = legalizerPatterns.find(op);
2425  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2426  return 0u;
2427 
2428  // Record this initial depth in case we encounter this op again when
2429  // recursively computing the depth.
2430  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2431 
2432  // Apply the cost model to the operation patterns, and update the minimum
2433  // depth.
2434  unsigned minDepth = applyCostModelToPatterns(
2435  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2436  minOpPatternDepth[op] = minDepth;
2437  return minDepth;
2438 }
2439 
2440 unsigned OperationLegalizer::applyCostModelToPatterns(
2441  LegalizationPatterns &patterns,
2442  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2443  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2444  unsigned minDepth = std::numeric_limits<unsigned>::max();
2445 
2446  // Compute the depth for each pattern within the set.
2448  patternsByDepth.reserve(patterns.size());
2449  for (const Pattern *pattern : patterns) {
2450  unsigned depth = 1;
2451  for (auto generatedOp : pattern->getGeneratedOps()) {
2452  unsigned generatedOpDepth = computeOpLegalizationDepth(
2453  generatedOp, minOpPatternDepth, legalizerPatterns);
2454  depth = std::max(depth, generatedOpDepth + 1);
2455  }
2456  patternsByDepth.emplace_back(pattern, depth);
2457 
2458  // Update the minimum depth of the pattern list.
2459  minDepth = std::min(minDepth, depth);
2460  }
2461 
2462  // If the operation only has one legalization pattern, there is no need to
2463  // sort them.
2464  if (patternsByDepth.size() == 1)
2465  return minDepth;
2466 
2467  // Sort the patterns by those likely to be the most beneficial.
2468  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2469  [](const std::pair<const Pattern *, unsigned> &lhs,
2470  const std::pair<const Pattern *, unsigned> &rhs) {
2471  // First sort by the smaller pattern legalization
2472  // depth.
2473  if (lhs.second != rhs.second)
2474  return lhs.second < rhs.second;
2475 
2476  // Then sort by the larger pattern benefit.
2477  auto lhsBenefit = lhs.first->getBenefit();
2478  auto rhsBenefit = rhs.first->getBenefit();
2479  return lhsBenefit > rhsBenefit;
2480  });
2481 
2482  // Update the legalization pattern to use the new sorted list.
2483  patterns.clear();
2484  for (auto &patternIt : patternsByDepth)
2485  patterns.push_back(patternIt.first);
2486  return minDepth;
2487 }
2488 
2489 //===----------------------------------------------------------------------===//
2490 // OperationConverter
2491 //===----------------------------------------------------------------------===//
namespace {
/// Selects how the operation converter reacts to operations that it cannot
/// legalize. The enumerator values are spelled out explicitly, and match the
/// implicit numbering (0, 1, 2).
enum OpConversionMode {
  /// Failed conversions are ignored, so illegal operations may remain in the
  /// IR. (Operations explicitly marked illegal by the target still cause an
  /// error; see `OperationConverter::convert`.)
  Partial = 0,

  /// Every operation must be legal for the target; otherwise the conversion
  /// fails.
  Full = 1,

  /// Operations are only analyzed for legality; on success, no rewrites are
  /// actually applied to them.
  Analysis = 2,
};
} // namespace
2507 
2508 namespace mlir {
2509 // This class converts operations to a given conversion target via a set of
2510 // rewrite patterns. The conversion behaves differently depending on the
2511 // conversion mode.
2513  explicit OperationConverter(const ConversionTarget &target,
2514  const FrozenRewritePatternSet &patterns,
2515  const ConversionConfig &config,
2516  OpConversionMode mode)
2517  : config(config), opLegalizer(target, patterns, this->config),
2518  mode(mode) {}
2519 
2520  /// Converts the given operations to the conversion target.
2522 
2523 private:
2524  /// Converts an operation with the given rewriter.
2525  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2526 
2527  /// This method is called after the conversion process to legalize any
2528  /// remaining artifacts and complete the conversion.
2529  LogicalResult finalize(ConversionPatternRewriter &rewriter);
2530 
2531  /// Legalize the types of converted block arguments.
2533  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2534  ConversionPatternRewriterImpl &rewriterImpl);
2535 
2536  /// Legalize any unresolved type materializations.
2537  LogicalResult legalizeUnresolvedMaterializations(
2538  ConversionPatternRewriter &rewriter,
2539  ConversionPatternRewriterImpl &rewriterImpl,
2540  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
2541 
2542  /// Legalize an operation result that was marked as "erased".
2544  legalizeErasedResult(Operation *op, OpResult result,
2545  ConversionPatternRewriterImpl &rewriterImpl);
2546 
2547  /// Legalize an operation result that was replaced with a value of a different
2548  /// type.
2549  LogicalResult legalizeChangedResultType(
2550  Operation *op, OpResult result, Value newValue,
2551  const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2552  ConversionPatternRewriterImpl &rewriterImpl,
2553  const DenseMap<Value, SmallVector<Value>> &inverseMapping);
2554 
2555  /// Dialect conversion configuration.
2556  ConversionConfig config;
2557 
2558  /// The legalizer to use when converting operations.
2559  OperationLegalizer opLegalizer;
2560 
2561  /// The conversion mode to use when legalizing operations.
2562  OpConversionMode mode;
2563 };
2564 } // namespace mlir
2565 
2566 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2567  Operation *op) {
2568  // Legalize the given operation.
2569  if (failed(opLegalizer.legalize(op, rewriter))) {
2570  // Handle the case of a failed conversion for each of the different modes.
2571  // Full conversions expect all operations to be converted.
2572  if (mode == OpConversionMode::Full)
2573  return op->emitError()
2574  << "failed to legalize operation '" << op->getName() << "'";
2575  // Partial conversions allow conversions to fail iff the operation was not
2576  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2577  // set, non-legalizable ops are added to that set.
2578  if (mode == OpConversionMode::Partial) {
2579  if (opLegalizer.isIllegal(op))
2580  return op->emitError()
2581  << "failed to legalize operation '" << op->getName()
2582  << "' that was explicitly marked illegal";
2583  if (config.unlegalizedOps)
2584  config.unlegalizedOps->insert(op);
2585  }
2586  } else if (mode == OpConversionMode::Analysis) {
2587  // Analysis conversions don't fail if any operations fail to legalize,
2588  // they are only interested in the operations that were successfully
2589  // legalized.
2590  if (config.legalizableOps)
2591  config.legalizableOps->insert(op);
2592  }
2593  return success();
2594 }
2595 
2597  if (ops.empty())
2598  return success();
2599  const ConversionTarget &target = opLegalizer.getTarget();
2600 
2601  // Compute the set of operations and blocks to convert.
2602  SmallVector<Operation *> toConvert;
2603  for (auto *op : ops) {
2605  [&](Operation *op) {
2606  toConvert.push_back(op);
2607  // Don't check this operation's children for conversion if the
2608  // operation is recursively legal.
2609  auto legalityInfo = target.isLegal(op);
2610  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2611  return WalkResult::skip();
2612  return WalkResult::advance();
2613  });
2614  }
2615 
2616  // Convert each operation and discard rewrites on failure.
2617  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2618  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2619 
2620  for (auto *op : toConvert)
2621  if (failed(convert(rewriter, op)))
2622  return rewriterImpl.undoRewrites(), failure();
2623 
2624  // Now that all of the operations have been converted, finalize the conversion
2625  // process to ensure any lingering conversion artifacts are cleaned up and
2626  // legalized.
2627  if (failed(finalize(rewriter)))
2628  return rewriterImpl.undoRewrites(), failure();
2629 
2630  // After a successful conversion, apply rewrites if this is not an analysis
2631  // conversion.
2632  if (mode == OpConversionMode::Analysis) {
2633  rewriterImpl.undoRewrites();
2634  } else {
2635  rewriterImpl.applyRewrites();
2636  }
2637  return success();
2638 }
2639 
2641 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2642  std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
2643  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2644  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2645  inverseMapping)) ||
2646  failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2647  return failure();
2648 
2649  // Process requested operation replacements.
2650  for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
2651  auto *opReplacement =
2652  dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
2653  if (!opReplacement || !opReplacement->hasChangedResults())
2654  continue;
2655  Operation *op = opReplacement->getOperation();
2656  for (OpResult result : op->getResults()) {
2657  Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2658 
2659  // If the operation result was replaced with null, all of the uses of this
2660  // value should be replaced.
2661  if (!newValue) {
2662  if (failed(legalizeErasedResult(op, result, rewriterImpl)))
2663  return failure();
2664  continue;
2665  }
2666 
2667  // Otherwise, check to see if the type of the result changed.
2668  if (result.getType() == newValue.getType())
2669  continue;
2670 
2671  // Compute the inverse mapping only if it is really needed.
2672  if (!inverseMapping)
2673  inverseMapping = rewriterImpl.mapping.getInverse();
2674 
2675  // Legalize this result.
2676  rewriter.setInsertionPoint(op);
2677  if (failed(legalizeChangedResultType(
2678  op, result, newValue, opReplacement->getConverter(), rewriter,
2679  rewriterImpl, *inverseMapping)))
2680  return failure();
2681  }
2682  }
2683  return success();
2684 }
2685 
2686 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2687  ConversionPatternRewriter &rewriter,
2688  ConversionPatternRewriterImpl &rewriterImpl) {
2689  // Functor used to check if all users of a value will be dead after
2690  // conversion.
2691  auto findLiveUser = [&](Value val) {
2692  auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2693  return rewriterImpl.isOpIgnored(user);
2694  });
2695  return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2696  };
2697  // Note: `rewrites` may be reallocated as the loop is running.
2698  for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
2699  ++i) {
2700  auto &rewrite = rewriterImpl.rewrites[i];
2701  if (auto *blockTypeConversionRewrite =
2702  dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
2703  if (failed(blockTypeConversionRewrite->materializeLiveConversions(
2704  findLiveUser)))
2705  return failure();
2706  }
2707  return success();
2708 }
2709 
2710 /// Replace the results of a materialization operation with the given values.
2711 static void
2713  ResultRange matResults, ValueRange values,
2714  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2715  matResults.replaceAllUsesWith(values);
2716 
2717  // For each of the materialization results, update the inverse mappings to
2718  // point to the replacement values.
2719  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
2720  auto inverseMapIt = inverseMapping.find(matResult);
2721  if (inverseMapIt == inverseMapping.end())
2722  continue;
2723 
2724  // Update the reverse mapping, or remove the mapping if we couldn't update
2725  // it. Not being able to update signals that the mapping would have become
2726  // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
2727  // propagated through temporary materializations. We simply drop the
2728  // mapping, and let the post-conversion replacement logic handle updating
2729  // uses.
2730  for (Value inverseMapVal : inverseMapIt->second)
2731  if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
2732  rewriterImpl.mapping.erase(inverseMapVal);
2733  }
2734 }
2735 
2736 /// Compute all of the unresolved materializations that will persist beyond the
2737 /// conversion process and therefore require a proper user materialization to be inserted.
2740  &materializationOps,
2741  ConversionPatternRewriter &rewriter,
2742  ConversionPatternRewriterImpl &rewriterImpl,
2743  DenseMap<Value, SmallVector<Value>> &inverseMapping,
2744  SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
2745  auto isLive = [&](Value value) {
2746  auto findFn = [&](Operation *user) {
2747  auto matIt = materializationOps.find(user);
2748  if (matIt != materializationOps.end())
2749  return !necessaryMaterializations.count(matIt->second);
2750  return rewriterImpl.isOpIgnored(user);
2751  };
2752  // This value may be replacing another value that has a live user.
2753  for (Value inv : inverseMapping.lookup(value))
2754  if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2755  return true;
2756  // Or have live users itself.
2757  return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
2758  };
2759 
2760  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
2761  [&](Value invalidRoot, Value value, Type type) {
2762  // Check to see if the input operation was remapped to a variant of the
2763  // output.
2764  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2765  if (remappedValue.getType() == type && remappedValue != invalidRoot)
2766  return remappedValue;
2767 
2768  // Check to see if the input is a materialization operation that
2769  // provides an inverse conversion. We just check blindly for
2770  // UnrealizedConversionCastOp here, but it has no effect on correctness.
2771  auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
2772  if (inputCastOp && inputCastOp->getNumOperands() == 1)
2773  return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2774  type);
2775 
2776  return Value();
2777  };
2778 
2780  for (auto &rewrite : rewriterImpl.rewrites) {
2781  auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
2782  if (!mat)
2783  continue;
2784  materializationOps.try_emplace(mat->getOperation(), mat);
2785  worklist.insert(mat);
2786  }
2787  while (!worklist.empty()) {
2788  UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
2789  UnrealizedConversionCastOp op = mat->getOperation();
2790 
2791  // We currently only handle target materializations here.
2792  assert(op->getNumResults() == 1 && "unexpected materialization type");
2793  OpResult opResult = op->getOpResult(0);
2794  Type outputType = opResult.getType();
2795  Operation::operand_range inputOperands = op.getOperands();
2796 
2797  // Try to forward propagate operands for user conversion casts that result
2798  // in the input types of the current cast.
2799  for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
2800  auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2801  if (!castOp)
2802  continue;
2803  if (castOp->getResultTypes() == inputOperands.getTypes()) {
2804  replaceMaterialization(rewriterImpl, opResult, inputOperands,
2805  inverseMapping);
2806  necessaryMaterializations.remove(materializationOps.lookup(user));
2807  }
2808  }
2809 
2810  // Try to avoid materializing a resolved materialization if possible.
2811  // Handle the case of a 1-1 materialization.
2812  if (inputOperands.size() == 1) {
2813  // Check to see if the input operation was remapped to a variant of the
2814  // output.
2815  Value remappedValue =
2816  lookupRemappedValue(opResult, inputOperands[0], outputType);
2817  if (remappedValue && remappedValue != opResult) {
2818  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2819  inverseMapping);
2820  necessaryMaterializations.remove(mat);
2821  continue;
2822  }
2823  } else {
2824  // TODO: Avoid materializing other types of conversions here.
2825  }
2826 
2827  // Check to see if this is an argument materialization.
2828  auto isBlockArg = [](Value v) { return isa<BlockArgument>(v); };
2829  if (llvm::any_of(op->getOperands(), isBlockArg) ||
2830  llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
2831  mat->setMaterializationKind(MaterializationKind::Argument);
2832  }
2833 
2834  // If the materialization does not have any live users, we don't need to
2835  // generate a user materialization for it.
2836  // FIXME: For argument materializations, we currently need to check if any
2837  // of the inverse mapped values are used because some patterns expect blind
2838  // value replacement even if the types differ in some cases. When those
2839  // patterns are fixed, we can drop the argument special case here.
2840  bool isMaterializationLive = isLive(opResult);
2841  if (mat->getMaterializationKind() == MaterializationKind::Argument)
2842  isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
2843  if (!isMaterializationLive)
2844  continue;
2845  if (!necessaryMaterializations.insert(mat))
2846  continue;
2847 
2848  // Reprocess input materializations to see if they have an updated status.
2849  for (Value input : inputOperands) {
2850  if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2851  if (auto *mat = materializationOps.lookup(parentOp))
2852  worklist.insert(mat);
2853  }
2854  }
2855  }
2856 }
2857 
2858 /// Legalize the given unresolved materialization. Returns success if the
2859 /// materialization was legalized, failure otherwise.
2861  UnresolvedMaterializationRewrite &mat,
2863  &materializationOps,
2864  ConversionPatternRewriter &rewriter,
2865  ConversionPatternRewriterImpl &rewriterImpl,
2866  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2867  auto findLiveUser = [&](auto &&users) {
2868  auto liveUserIt = llvm::find_if_not(
2869  users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
2870  return liveUserIt == users.end() ? nullptr : *liveUserIt;
2871  };
2872 
2873  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
2874  [&](Value value, Type type) {
2875  // Check to see if the input operation was remapped to a variant of the
2876  // output.
2877  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2878  if (remappedValue.getType() == type)
2879  return remappedValue;
2880  return Value();
2881  };
2882 
2883  UnrealizedConversionCastOp op = mat.getOperation();
2884  if (!rewriterImpl.ignoredOps.insert(op))
2885  return success();
2886 
2887  // We currently only handle target materializations here.
2888  OpResult opResult = op->getOpResult(0);
2889  Operation::operand_range inputOperands = op.getOperands();
2890  Type outputType = opResult.getType();
2891 
2892  // If any input to this materialization is another materialization, resolve
2893  // the input first.
2894  for (Value value : op->getOperands()) {
2895  auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
2896  if (!valueCast)
2897  continue;
2898 
2899  auto matIt = materializationOps.find(valueCast);
2900  if (matIt != materializationOps.end())
2902  *matIt->second, materializationOps, rewriter, rewriterImpl,
2903  inverseMapping)))
2904  return failure();
2905  }
2906 
2907  // Perform a last ditch attempt to avoid materializing a resolved
2908  // materialization if possible.
2909  // Handle the case of a 1-1 materialization.
2910  if (inputOperands.size() == 1) {
2911  // Check to see if the input operation was remapped to a variant of the
2912  // output.
2913  Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2914  if (remappedValue && remappedValue != opResult) {
2915  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2916  inverseMapping);
2917  return success();
2918  }
2919  } else {
2920  // TODO: Avoid materializing other types of conversions here.
2921  }
2922 
2923  // Try to materialize the conversion.
2924  if (const TypeConverter *converter = mat.getConverter()) {
2925  // FIXME: Determine a suitable insertion location when there are multiple
2926  // inputs.
2927  if (inputOperands.size() == 1)
2928  rewriter.setInsertionPointAfterValue(inputOperands.front());
2929  else
2930  rewriter.setInsertionPoint(op);
2931 
2932  Value newMaterialization;
2933  switch (mat.getMaterializationKind()) {
2935  // Try to materialize an argument conversion.
2936  // FIXME: The current argument materialization hook expects the original
2937  // output type, even though it doesn't use that as the actual output type
2938  // of the generated IR. The output type is just used as an indicator of
2939  // the type of materialization to do. This behavior is really awkward in
2940  // that it diverges from the behavior of the other hooks, and can be
2941  // easily misunderstood. We should clean up the argument hooks to better
2942  // represent the desired invariants we actually care about.
2943  newMaterialization = converter->materializeArgumentConversion(
2944  rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
2945  if (newMaterialization)
2946  break;
2947 
2948  // If an argument materialization failed, fallback to trying a target
2949  // materialization.
2950  [[fallthrough]];
2951  case MaterializationKind::Target:
2952  newMaterialization = converter->materializeTargetConversion(
2953  rewriter, op->getLoc(), outputType, inputOperands);
2954  break;
2955  }
2956  if (newMaterialization) {
2957  replaceMaterialization(rewriterImpl, opResult, newMaterialization,
2958  inverseMapping);
2959  return success();
2960  }
2961  }
2962 
2964  << "failed to legalize unresolved materialization "
2965  "from "
2966  << inputOperands.getTypes() << " to " << outputType
2967  << " that remained live after conversion";
2968  if (Operation *liveUser = findLiveUser(op->getUsers())) {
2969  diag.attachNote(liveUser->getLoc())
2970  << "see existing live user here: " << *liveUser;
2971  }
2972  return failure();
2973 }
2974 
2975 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2976  ConversionPatternRewriter &rewriter,
2977  ConversionPatternRewriterImpl &rewriterImpl,
2978  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
2979  inverseMapping = rewriterImpl.mapping.getInverse();
2980 
2981  // As an initial step, compute all of the inserted materializations that we
2982  // expect to persist beyond the conversion process.
2984  SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
2985  computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
2986  *inverseMapping, necessaryMaterializations);
2987 
2988  // Once computed, legalize any necessary materializations.
2989  for (auto *mat : necessaryMaterializations) {
2991  *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
2992  return failure();
2993  }
2994  return success();
2995 }
2996 
2997 LogicalResult OperationConverter::legalizeErasedResult(
2998  Operation *op, OpResult result,
2999  ConversionPatternRewriterImpl &rewriterImpl) {
3000  // If the operation result was replaced with null, all of the uses of this
3001  // value should be replaced.
3002  auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
3003  return rewriterImpl.isOpIgnored(user);
3004  });
3005  if (liveUserIt != result.user_end()) {
3006  InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
3007  << op->getName() << "' marked as erased";
3008  diag.attachNote(liveUserIt->getLoc())
3009  << "found live user of result #" << result.getResultNumber() << ": "
3010  << *liveUserIt;
3011  return failure();
3012  }
3013  return success();
3014 }
3015 
3016 /// Finds a user of the given value, or of any other value that the given value
3017 /// replaced, that was not replaced in the conversion process.
3019  Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
3020  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
3021  SmallVector<Value> worklist(1, initialValue);
3022  while (!worklist.empty()) {
3023  Value value = worklist.pop_back_val();
3024 
3025  // Walk the users of this value to see if there are any live users that
3026  // weren't replaced during conversion.
3027  auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
3028  return rewriterImpl.isOpIgnored(user);
3029  });
3030  if (liveUserIt != value.user_end())
3031  return *liveUserIt;
3032  auto mapIt = inverseMapping.find(value);
3033  if (mapIt != inverseMapping.end())
3034  worklist.append(mapIt->second);
3035  }
3036  return nullptr;
3037 }
3038 
3039 LogicalResult OperationConverter::legalizeChangedResultType(
3040  Operation *op, OpResult result, Value newValue,
3041  const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
3042  ConversionPatternRewriterImpl &rewriterImpl,
3043  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
3044  Operation *liveUser =
3045  findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
3046  if (!liveUser)
3047  return success();
3048 
3049  // Functor used to emit a conversion error for a failed materialization.
3050  auto emitConversionError = [&] {
3052  << "failed to materialize conversion for result #"
3053  << result.getResultNumber() << " of operation '"
3054  << op->getName()
3055  << "' that remained live after conversion";
3056  diag.attachNote(liveUser->getLoc())
3057  << "see existing live user here: " << *liveUser;
3058  return failure();
3059  };
3060 
3061  // If the replacement has a type converter, attempt to materialize a
3062  // conversion back to the original type.
3063  if (!replConverter)
3064  return emitConversionError();
3065 
3066  // Materialize a conversion for this live result value.
3067  Type resultType = result.getType();
3068  Value convertedValue = replConverter->materializeSourceConversion(
3069  rewriter, op->getLoc(), resultType, newValue);
3070  if (!convertedValue)
3071  return emitConversionError();
3072 
3073  rewriterImpl.mapping.map(result, convertedValue);
3074  return success();
3075 }
3076 
3077 //===----------------------------------------------------------------------===//
3078 // Type Conversion
3079 //===----------------------------------------------------------------------===//
3080 
3082  ArrayRef<Type> types) {
3083  assert(!types.empty() && "expected valid types");
3084  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3085  addInputs(types);
3086 }
3087 
3089  assert(!types.empty() &&
3090  "1->0 type remappings don't need to be added explicitly");
3091  argTypes.append(types.begin(), types.end());
3092 }
3093 
3094 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3095  unsigned newInputNo,
3096  unsigned newInputCount) {
3097  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3098  assert(newInputCount != 0 && "expected valid input count");
3099  remappedInputs[origInputNo] =
3100  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
3101 }
3102 
3104  Value replacementValue) {
3105  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3106  remappedInputs[origInputNo] =
3107  InputMapping{origInputNo, /*size=*/0, replacementValue};
3108 }
3109 
3111  SmallVectorImpl<Type> &results) const {
3112  {
3113  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3114  std::defer_lock);
3116  cacheReadLock.lock();
3117  auto existingIt = cachedDirectConversions.find(t);
3118  if (existingIt != cachedDirectConversions.end()) {
3119  if (existingIt->second)
3120  results.push_back(existingIt->second);
3121  return success(existingIt->second != nullptr);
3122  }
3123  auto multiIt = cachedMultiConversions.find(t);
3124  if (multiIt != cachedMultiConversions.end()) {
3125  results.append(multiIt->second.begin(), multiIt->second.end());
3126  return success();
3127  }
3128  }
3129  // Walk the added converters in reverse order to apply the most recently
3130  // registered first.
3131  size_t currentCount = results.size();
3132 
3133  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3134  std::defer_lock);
3135 
3136  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
3137  if (std::optional<LogicalResult> result = converter(t, results)) {
3139  cacheWriteLock.lock();
3140  if (!succeeded(*result)) {
3141  cachedDirectConversions.try_emplace(t, nullptr);
3142  return failure();
3143  }
3144  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3145  if (newTypes.size() == 1)
3146  cachedDirectConversions.try_emplace(t, newTypes.front());
3147  else
3148  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3149  return success();
3150  }
3151  }
3152  return failure();
3153 }
3154 
3156  // Use the multi-type result version to convert the type.
3157  SmallVector<Type, 1> results;
3158  if (failed(convertType(t, results)))
3159  return nullptr;
3160 
3161  // Check to ensure that only one type was produced.
3162  return results.size() == 1 ? results.front() : nullptr;
3163 }
3164 
3167  SmallVectorImpl<Type> &results) const {
3168  for (Type type : types)
3169  if (failed(convertType(type, results)))
3170  return failure();
3171  return success();
3172 }
3173 
3174 bool TypeConverter::isLegal(Type type) const {
3175  return convertType(type) == type;
3176 }
3178  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
3179 }
3180 
3181 bool TypeConverter::isLegal(Region *region) const {
3182  return llvm::all_of(*region, [this](Block &block) {
3183  return isLegal(block.getArgumentTypes());
3184  });
3185 }
3186 
3187 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3188  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
3189 }
3190 
3193  SignatureConversion &result) const {
3194  // Try to convert the given input type.
3195  SmallVector<Type, 1> convertedTypes;
3196  if (failed(convertType(type, convertedTypes)))
3197  return failure();
3198 
3199  // If this argument is being dropped, there is nothing left to do.
3200  if (convertedTypes.empty())
3201  return success();
3202 
3203  // Otherwise, add the new inputs.
3204  result.addInputs(inputNo, convertedTypes);
3205  return success();
3206 }
3209  SignatureConversion &result,
3210  unsigned origInputOffset) const {
3211  for (unsigned i = 0, e = types.size(); i != e; ++i)
3212  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3213  return failure();
3214  return success();
3215 }
3216 
3217 Value TypeConverter::materializeConversion(
3218  ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
3219  Location loc, Type resultType, ValueRange inputs) const {
3220  for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
3221  if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
3222  return *result;
3223  return nullptr;
3224 }
3225 
3226 std::optional<TypeConverter::SignatureConversion>
3228  SignatureConversion conversion(block->getNumArguments());
3229  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3230  return std::nullopt;
3231  return conversion;
3232 }
3233 
3234 //===----------------------------------------------------------------------===//
3235 // Type attribute conversion
3236 //===----------------------------------------------------------------------===//
3239  return AttributeConversionResult(attr, resultTag);
3240 }
3241 
3244  return AttributeConversionResult(nullptr, naTag);
3245 }
3246 
3249  return AttributeConversionResult(nullptr, abortTag);
3250 }
3251 
3253  return impl.getInt() == resultTag;
3254 }
3255 
3257  return impl.getInt() == naTag;
3258 }
3259 
3261  return impl.getInt() == abortTag;
3262 }
3263 
3265  assert(hasResult() && "Cannot get result from N/A or abort");
3266  return impl.getPointer();
3267 }
3268 
3269 std::optional<Attribute>
3271  for (const TypeAttributeConversionCallbackFn &fn :
3272  llvm::reverse(typeAttributeConversions)) {
3273  AttributeConversionResult res = fn(type, attr);
3274  if (res.hasResult())
3275  return res.getResult();
3276  if (res.isAbort())
3277  return std::nullopt;
3278  }
3279  return std::nullopt;
3280 }
3281 
3282 //===----------------------------------------------------------------------===//
3283 // FunctionOpInterfaceSignatureConversion
3284 //===----------------------------------------------------------------------===//
3285 
3286 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3287  const TypeConverter &typeConverter,
3288  ConversionPatternRewriter &rewriter) {
3289  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3290  if (!type)
3291  return failure();
3292 
3293  // Convert the original function types.
3294  TypeConverter::SignatureConversion result(type.getNumInputs());
3295  SmallVector<Type, 1> newResults;
3296  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3297  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3298  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3299  typeConverter, &result)))
3300  return failure();
3301 
3302  // Update the function signature in-place.
3303  auto newType = FunctionType::get(rewriter.getContext(),
3304  result.getConvertedTypes(), newResults);
3305 
3306  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3307 
3308  return success();
3309 }
3310 
3311 /// Create a default conversion pattern that rewrites the type signature of a
3312 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3313 /// represent their type.
3314 namespace {
3315 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3316  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3317  MLIRContext *ctx,
3318  const TypeConverter &converter)
3319  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3320 
3322  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3323  ConversionPatternRewriter &rewriter) const override {
3324  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3325  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3326  }
3327 };
3328 
3329 struct AnyFunctionOpInterfaceSignatureConversion
3330  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3332 
3334  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3335  ConversionPatternRewriter &rewriter) const override {
3336  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3337  }
3338 };
3339 } // namespace
3340 
3343  const TypeConverter &converter,
3344  ConversionPatternRewriter &rewriter) {
3345  assert(op && "Invalid op");
3346  Location loc = op->getLoc();
3347  if (converter.isLegal(op))
3348  return rewriter.notifyMatchFailure(loc, "op already legal");
3349 
3350  OperationState newOp(loc, op->getName());
3351  newOp.addOperands(operands);
3352 
3353  SmallVector<Type> newResultTypes;
3354  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3355  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3356 
3357  newOp.addTypes(newResultTypes);
3358  newOp.addAttributes(op->getAttrs());
3359  return rewriter.create(newOp);
3360 }
3361 
3363  StringRef functionLikeOpName, RewritePatternSet &patterns,
3364  const TypeConverter &converter) {
3365  patterns.add<FunctionOpInterfaceSignatureConversion>(
3366  functionLikeOpName, patterns.getContext(), converter);
3367 }
3368 
3370  RewritePatternSet &patterns, const TypeConverter &converter) {
3371  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3372  converter, patterns.getContext());
3373 }
3374 
3375 //===----------------------------------------------------------------------===//
3376 // ConversionTarget
3377 //===----------------------------------------------------------------------===//
3378 
3380  LegalizationAction action) {
3381  legalOperations[op].action = action;
3382 }
3383 
3385  LegalizationAction action) {
3386  for (StringRef dialect : dialectNames)
3387  legalDialects[dialect] = action;
3388 }
3389 
3391  -> std::optional<LegalizationAction> {
3392  std::optional<LegalizationInfo> info = getOpInfo(op);
3393  return info ? info->action : std::optional<LegalizationAction>();
3394 }
3395 
3397  -> std::optional<LegalOpDetails> {
3398  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3399  if (!info)
3400  return std::nullopt;
3401 
3402  // Returns true if this operation instance is known to be legal.
3403  auto isOpLegal = [&] {
3404  // Handle dynamic legality either with the provided legality function.
3405  if (info->action == LegalizationAction::Dynamic) {
3406  std::optional<bool> result = info->legalityFn(op);
3407  if (result)
3408  return *result;
3409  }
3410 
3411  // Otherwise, the operation is only legal if it was marked 'Legal'.
3412  return info->action == LegalizationAction::Legal;
3413  };
3414  if (!isOpLegal())
3415  return std::nullopt;
3416 
3417  // This operation is legal, compute any additional legality information.
3418  LegalOpDetails legalityDetails;
3419  if (info->isRecursivelyLegal) {
3420  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3421  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3422  legalityDetails.isRecursivelyLegal =
3423  legalityFnIt->second(op).value_or(true);
3424  } else {
3425  legalityDetails.isRecursivelyLegal = true;
3426  }
3427  }
3428  return legalityDetails;
3429 }
3430 
3432  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3433  if (!info)
3434  return false;
3435 
3436  if (info->action == LegalizationAction::Dynamic) {
3437  std::optional<bool> result = info->legalityFn(op);
3438  if (!result)
3439  return false;
3440 
3441  return !(*result);
3442  }
3443 
3444  return info->action == LegalizationAction::Illegal;
3445 }
3446 
3450  if (!oldCallback)
3451  return newCallback;
3452 
3453  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3454  Operation *op) -> std::optional<bool> {
3455  if (std::optional<bool> result = newCl(op))
3456  return *result;
3457 
3458  return oldCl(op);
3459  };
3460  return chain;
3461 }
3462 
3463 void ConversionTarget::setLegalityCallback(
3464  OperationName name, const DynamicLegalityCallbackFn &callback) {
3465  assert(callback && "expected valid legality callback");
3466  auto *infoIt = legalOperations.find(name);
3467  assert(infoIt != legalOperations.end() &&
3468  infoIt->second.action == LegalizationAction::Dynamic &&
3469  "expected operation to already be marked as dynamically legal");
3470  infoIt->second.legalityFn =
3471  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3472 }
3473 
3475  OperationName name, const DynamicLegalityCallbackFn &callback) {
3476  auto *infoIt = legalOperations.find(name);
3477  assert(infoIt != legalOperations.end() &&
3478  infoIt->second.action != LegalizationAction::Illegal &&
3479  "expected operation to already be marked as legal");
3480  infoIt->second.isRecursivelyLegal = true;
3481  if (callback)
3482  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3483  std::move(opRecursiveLegalityFns[name]), callback);
3484  else
3485  opRecursiveLegalityFns.erase(name);
3486 }
3487 
3488 void ConversionTarget::setLegalityCallback(
3489  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3490  assert(callback && "expected valid legality callback");
3491  for (StringRef dialect : dialects)
3492  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3493  std::move(dialectLegalityFns[dialect]), callback);
3494 }
3495 
3496 void ConversionTarget::setLegalityCallback(
3497  const DynamicLegalityCallbackFn &callback) {
3498  assert(callback && "expected valid legality callback");
3499  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3500 }
3501 
3502 auto ConversionTarget::getOpInfo(OperationName op) const
3503  -> std::optional<LegalizationInfo> {
3504  // Check for info for this specific operation.
3505  const auto *it = legalOperations.find(op);
3506  if (it != legalOperations.end())
3507  return it->second;
3508  // Check for info for the parent dialect.
3509  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3510  if (dialectIt != legalDialects.end()) {
3511  DynamicLegalityCallbackFn callback;
3512  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3513  if (dialectFn != dialectLegalityFns.end())
3514  callback = dialectFn->second;
3515  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3516  callback};
3517  }
3518  // Otherwise, check if we mark unknown operations as dynamic.
3519  if (unknownLegalityFn)
3520  return LegalizationInfo{LegalizationAction::Dynamic,
3521  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3522  return std::nullopt;
3523 }
3524 
3525 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3526 //===----------------------------------------------------------------------===//
3527 // PDL Configuration
3528 //===----------------------------------------------------------------------===//
3529 
3531  auto &rewriterImpl =
3532  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3533  rewriterImpl.currentTypeConverter = getTypeConverter();
3534 }
3535 
3537  auto &rewriterImpl =
3538  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3539  rewriterImpl.currentTypeConverter = nullptr;
3540 }
3541 
3542 /// Remap the given value using the rewriter and the type converter in the
3543 /// provided config.
3546  SmallVector<Value> mappedValues;
3547  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3548  return failure();
3549  return std::move(mappedValues);
3550 }
3551 
3553  patterns.getPDLPatterns().registerRewriteFunction(
3554  "convertValue",
3555  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3556  auto results = pdllConvertValues(
3557  static_cast<ConversionPatternRewriter &>(rewriter), value);
3558  if (failed(results))
3559  return failure();
3560  return results->front();
3561  });
3562  patterns.getPDLPatterns().registerRewriteFunction(
3563  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3564  return pdllConvertValues(
3565  static_cast<ConversionPatternRewriter &>(rewriter), values);
3566  });
3567  patterns.getPDLPatterns().registerRewriteFunction(
3568  "convertType",
3569  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3570  auto &rewriterImpl =
3571  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3572  if (const TypeConverter *converter =
3573  rewriterImpl.currentTypeConverter) {
3574  if (Type newType = converter->convertType(type))
3575  return newType;
3576  return failure();
3577  }
3578  return type;
3579  });
3580  patterns.getPDLPatterns().registerRewriteFunction(
3581  "convertTypes",
3582  [](PatternRewriter &rewriter,
3584  auto &rewriterImpl =
3585  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3586  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3587  if (!converter)
3588  return SmallVector<Type>(types);
3589 
3590  SmallVector<Type> remappedTypes;
3591  if (failed(converter->convertTypes(types, remappedTypes)))
3592  return failure();
3593  return std::move(remappedTypes);
3594  });
3595 }
3596 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3597 
3598 //===----------------------------------------------------------------------===//
3599 // Op Conversion Entry Points
3600 //===----------------------------------------------------------------------===//
3601 
3602 //===----------------------------------------------------------------------===//
3603 // Partial Conversion
3604 
3606  ArrayRef<Operation *> ops, const ConversionTarget &target,
3607  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3608  OperationConverter opConverter(target, patterns, config,
3609  OpConversionMode::Partial);
3610  return opConverter.convertOperations(ops);
3611 }
3614  const FrozenRewritePatternSet &patterns,
3615  ConversionConfig config) {
3616  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3617 }
3618 
3619 //===----------------------------------------------------------------------===//
3620 // Full Conversion
3621 
3623  const ConversionTarget &target,
3624  const FrozenRewritePatternSet &patterns,
3625  ConversionConfig config) {
3626  OperationConverter opConverter(target, patterns, config,
3627  OpConversionMode::Full);
3628  return opConverter.convertOperations(ops);
3629 }
3631  const ConversionTarget &target,
3632  const FrozenRewritePatternSet &patterns,
3633  ConversionConfig config) {
3634  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3635 }
3636 
3637 //===----------------------------------------------------------------------===//
3638 // Analysis Conversion
3639 
3642  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3643  OperationConverter opConverter(target, patterns, config,
3644  OpConversionMode::Analysis);
3645  return opConverter.convertOperations(ops);
3646 }
3649  const FrozenRewritePatternSet &patterns,
3650  ConversionConfig config) {
3651  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3652 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterializationRewrite &mat, DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
static RewriteTy * findSingleRewrite(R &&rewrites, Block *block)
Find the single rewrite object of the specified type and block among the given rewrites.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterializationRewrite * > &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process,...
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:315
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:324
Location getLoc() const
Return the location for this argument.
Definition: Value.h:330
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:137
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
bool empty()
Definition: Block.h:145
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:93
OpListType & getOperations()
Definition: Block.h:134
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
iterator end()
Definition: Block.h:141
iterator begin()
Definition: Block.h:140
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region.
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:756
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:329
Block::iterator getPoint() const
Definition: Builders.h:342
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:339
Block * getBlock() const
Definition: Builders.h:341
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:423
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within 'results'.
Definition: Builders.cpp:482
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:263
This is a value defined by a result of an operation.
Definition: Value.h:453
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:465
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
OpResult getOpResult(unsigned idx)
Definition: Operation.h:416
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:605
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:892
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:42
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:775
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:72
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:128
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:93
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:89
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:281
MLIRContext * getContext() const
Definition: PatternMatch.h:812
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:818
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:836
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:708
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
user_iterator user_end() const
Definition: Value.h:223
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:224
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition: Argument.h:64
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a single value that remaps an existing signature input...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
DenseSet< void * > erased
Pointers to all erased operations and blocks.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc, ValueRange inputs, Type origOutputType, Type outputType, const TypeConverter *converter)
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
LogicalResult convertNonEntryRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions={})
Convert the types of non-entry block arguments within the given region.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter)
Apply a signature conversion on the given region, using converter for materializations if not null.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
Value buildUnresolvedMaterialization(MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, Type origOutputType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
FailureOr< Block * > convertBlockSignature(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion=nullptr)
Attempt to convert the signature of the given block, if successful a new block is returned containing...
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, const TypeConverter *converter)
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.