MLIR  19.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Iterators.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::detail;
29 
30 #define DEBUG_TYPE "dialect-conversion"
31 
32 /// A utility function to log a successful result for the given reason.
33 template <typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
35  LLVM_DEBUG({
36  os.unindent();
37  os.startLine() << "} -> SUCCESS";
38  if (!fmt.empty())
39  os.getOStream() << " : "
40  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41  os.getOStream() << "\n";
42  });
43 }
44 
45 /// A utility function to log a failure result for the given reason.
46 template <typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
48  LLVM_DEBUG({
49  os.unindent();
50  os.startLine() << "} -> FAILURE : "
51  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
52  << "\n";
53  });
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // ConversionValueMapping
58 //===----------------------------------------------------------------------===//
59 
60 namespace {
61 /// This class wraps a IRMapping to provide recursive lookup
62 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
63 struct ConversionValueMapping {
64  /// Lookup a mapped value within the map. If a mapping for the provided value
65  /// does not exist then return the provided value. If `desiredType` is
66  /// non-null, returns the most recently mapped value with that type. If an
67  /// operand of that type does not exist, defaults to normal behavior.
68  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
69 
70  /// Lookup a mapped value within the map, or return null if a mapping does not
71  /// exist. If a mapping exists, this follows the same behavior of
72  /// `lookupOrDefault`.
73  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
74 
75  /// Map a value to the one provided.
76  void map(Value oldVal, Value newVal) {
77  LLVM_DEBUG({
78  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
79  assert(it != oldVal && "inserting cyclic mapping");
80  });
81  mapping.map(oldVal, newVal);
82  }
83 
84  /// Try to map a value to the one provided. Returns false if a transitive
85  /// mapping from the new value to the old value already exists, true if the
86  /// map was updated.
87  bool tryMap(Value oldVal, Value newVal);
88 
89  /// Drop the last mapping for the given value.
90  void erase(Value value) { mapping.erase(value); }
91 
92  /// Returns the inverse raw value mapping (without recursive query support).
93  DenseMap<Value, SmallVector<Value>> getInverse() const {
95  for (auto &it : mapping.getValueMap())
96  inverse[it.second].push_back(it.first);
97  return inverse;
98  }
99 
100 private:
101  /// Current value mappings.
102  IRMapping mapping;
103 };
104 } // namespace
105 
106 Value ConversionValueMapping::lookupOrDefault(Value from,
107  Type desiredType) const {
108  // If there was no desired type, simply find the leaf value.
109  if (!desiredType) {
110  // If this value had a valid mapping, unmap that value as well in the case
111  // that it was also replaced.
112  while (auto mappedValue = mapping.lookupOrNull(from))
113  from = mappedValue;
114  return from;
115  }
116 
117  // Otherwise, try to find the deepest value that has the desired type.
118  Value desiredValue;
119  do {
120  if (from.getType() == desiredType)
121  desiredValue = from;
122 
123  Value mappedValue = mapping.lookupOrNull(from);
124  if (!mappedValue)
125  break;
126  from = mappedValue;
127  } while (true);
128 
129  // If the desired value was found use it, otherwise default to the leaf value.
130  return desiredValue ? desiredValue : from;
131 }
132 
133 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
134  Value result = lookupOrDefault(from, desiredType);
135  if (result == from || (desiredType && result.getType() != desiredType))
136  return nullptr;
137  return result;
138 }
139 
140 bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
141  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
142  if (it == oldVal)
143  return false;
144  map(oldVal, newVal);
145  return true;
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // Rewriter and Translation State
150 //===----------------------------------------------------------------------===//
151 namespace {
/// Snapshot of the conversion rewriter state, expressed as three counters.
/// A snapshot can be saved before a set of rewrites and used later to undo
/// them.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites(numRewrites),
        numIgnoredOperations(numIgnoredOperations),
        numReplacedOps(numReplacedOps) {}

  /// Number of rewrites performed at the time of the snapshot.
  unsigned numRewrites;

  /// Number of ignored operations at the time of the snapshot.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure at the time of the snapshot.
  unsigned numReplacedOps;
};
169 
170 //===----------------------------------------------------------------------===//
171 // IR rewrites
172 //===----------------------------------------------------------------------===//
173 
174 /// An IR rewrite that can be committed (upon success) or rolled back (upon
175 /// failure).
176 ///
177 /// The dialect conversion keeps track of IR modifications (requested by the
178 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
179 /// are directly applied to the IR as the rewriter API is used, some are applied
180 /// partially, and some are delayed until the `IRRewrite` objects are committed.
181 class IRRewrite {
182 public:
183  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
184  /// Enum values are ordered, so that they can be used in `classof`: first all
185  /// block rewrites, then all operation rewrites.
186  enum class Kind {
187  // Block rewrites
188  CreateBlock,
189  EraseBlock,
190  InlineBlock,
191  MoveBlock,
192  BlockTypeConversion,
193  ReplaceBlockArg,
194  // Operation rewrites
195  MoveOperation,
196  ModifyOperation,
197  ReplaceOperation,
198  CreateOperation,
199  UnresolvedMaterialization
200  };
201 
202  virtual ~IRRewrite() = default;
203 
204  /// Roll back the rewrite. Operations may be erased during rollback.
205  virtual void rollback() = 0;
206 
207  /// Commit the rewrite. At this point, it is certain that the dialect
208  /// conversion will succeed. All IR modifications, except for operation/block
209  /// erasure, must be performed through the given rewriter.
210  ///
211  /// Instead of erasing operations/blocks, they should merely be unlinked
212  /// commit phase and finally be erased during the cleanup phase. This is
213  /// because internal dialect conversion state (such as `mapping`) may still
214  /// be using them.
215  ///
216  /// Any IR modification that was already performed before the commit phase
217  /// (e.g., insertion of an op) must be communicated to the listener that may
218  /// be attached to the given rewriter.
219  virtual void commit(RewriterBase &rewriter) {}
220 
221  /// Cleanup operations/blocks. Cleanup is called after commit.
222  virtual void cleanup(RewriterBase &rewriter) {}
223 
224  Kind getKind() const { return kind; }
225 
226  static bool classof(const IRRewrite *rewrite) { return true; }
227 
228 protected:
229  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
230  : kind(kind), rewriterImpl(rewriterImpl) {}
231 
232  const ConversionConfig &getConfig() const;
233 
234  const Kind kind;
235  ConversionPatternRewriterImpl &rewriterImpl;
236 };
237 
238 /// A block rewrite.
239 class BlockRewrite : public IRRewrite {
240 public:
241  /// Return the block that this rewrite operates on.
242  Block *getBlock() const { return block; }
243 
244  static bool classof(const IRRewrite *rewrite) {
245  return rewrite->getKind() >= Kind::CreateBlock &&
246  rewrite->getKind() <= Kind::ReplaceBlockArg;
247  }
248 
249 protected:
250  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
251  Block *block)
252  : IRRewrite(kind, rewriterImpl), block(block) {}
253 
254  // The block that this rewrite operates on.
255  Block *block;
256 };
257 
258 /// Creation of a block. Block creations are immediately reflected in the IR.
259 /// There is no extra work to commit the rewrite. During rollback, the newly
260 /// created block is erased.
261 class CreateBlockRewrite : public BlockRewrite {
262 public:
263  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
264  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
265 
266  static bool classof(const IRRewrite *rewrite) {
267  return rewrite->getKind() == Kind::CreateBlock;
268  }
269 
270  void commit(RewriterBase &rewriter) override {
271  // The block was already created and inserted. Just inform the listener.
272  if (auto *listener = rewriter.getListener())
273  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
274  }
275 
276  void rollback() override {
277  // Unlink all of the operations within this block, they will be deleted
278  // separately.
279  auto &blockOps = block->getOperations();
280  while (!blockOps.empty())
281  blockOps.remove(blockOps.begin());
282  block->dropAllUses();
283  if (block->getParent())
284  block->erase();
285  else
286  delete block;
287  }
288 };
289 
290 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
291 /// blocks are immediately unlinked, but only erased during cleanup. This makes
292 /// it easier to rollback a block erasure: the block is simply inserted into its
293 /// original location.
294 class EraseBlockRewrite : public BlockRewrite {
295 public:
296  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
297  Region *region, Block *insertBeforeBlock)
298  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), region(region),
299  insertBeforeBlock(insertBeforeBlock) {}
300 
301  static bool classof(const IRRewrite *rewrite) {
302  return rewrite->getKind() == Kind::EraseBlock;
303  }
304 
305  ~EraseBlockRewrite() override {
306  assert(!block &&
307  "rewrite was neither rolled back nor committed/cleaned up");
308  }
309 
310  void rollback() override {
311  // The block (owned by this rewrite) was not actually erased yet. It was
312  // just unlinked. Put it back into its original position.
313  assert(block && "expected block");
314  auto &blockList = region->getBlocks();
315  Region::iterator before = insertBeforeBlock
316  ? Region::iterator(insertBeforeBlock)
317  : blockList.end();
318  blockList.insert(before, block);
319  block = nullptr;
320  }
321 
322  void commit(RewriterBase &rewriter) override {
323  // Erase the block.
324  assert(block && "expected block");
325  assert(block->empty() && "expected empty block");
326 
327  // Notify the listener that the block is about to be erased.
328  if (auto *listener =
329  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
330  listener->notifyBlockErased(block);
331  }
332 
333  void cleanup(RewriterBase &rewriter) override {
334  // Erase the block.
335  block->dropAllDefinedValueUses();
336  delete block;
337  block = nullptr;
338  }
339 
340 private:
341  // The region in which this block was previously contained.
342  Region *region;
343 
344  // The original successor of this block before it was unlinked. "nullptr" if
345  // this block was the only block in the region.
346  Block *insertBeforeBlock;
347 };
348 
349 /// Inlining of a block. This rewrite is immediately reflected in the IR.
350 /// Note: This rewrite represents only the inlining of the operations. The
351 /// erasure of the inlined block is a separate rewrite.
352 class InlineBlockRewrite : public BlockRewrite {
353 public:
354  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
355  Block *sourceBlock, Block::iterator before)
356  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
357  sourceBlock(sourceBlock),
358  firstInlinedInst(sourceBlock->empty() ? nullptr
359  : &sourceBlock->front()),
360  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
361  // If a listener is attached to the dialect conversion, ops must be moved
362  // one-by-one. When they are moved in bulk, notifications cannot be sent
363  // because the ops that used to be in the source block at the time of the
364  // inlining (before the "commit" phase) are unknown at the time when
365  // notifications are sent (which is during the "commit" phase).
366  assert(!getConfig().listener &&
367  "InlineBlockRewrite not supported if listener is attached");
368  }
369 
370  static bool classof(const IRRewrite *rewrite) {
371  return rewrite->getKind() == Kind::InlineBlock;
372  }
373 
374  void rollback() override {
375  // Put the operations from the destination block (owned by the rewrite)
376  // back into the source block.
377  if (firstInlinedInst) {
378  assert(lastInlinedInst && "expected operation");
379  sourceBlock->getOperations().splice(sourceBlock->begin(),
380  block->getOperations(),
381  Block::iterator(firstInlinedInst),
382  ++Block::iterator(lastInlinedInst));
383  }
384  }
385 
386 private:
387  // The block that originally contained the operations.
388  Block *sourceBlock;
389 
390  // The first inlined operation.
391  Operation *firstInlinedInst;
392 
393  // The last inlined operation.
394  Operation *lastInlinedInst;
395 };
396 
397 /// Moving of a block. This rewrite is immediately reflected in the IR.
398 class MoveBlockRewrite : public BlockRewrite {
399 public:
400  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
401  Region *region, Block *insertBeforeBlock)
402  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
403  insertBeforeBlock(insertBeforeBlock) {}
404 
405  static bool classof(const IRRewrite *rewrite) {
406  return rewrite->getKind() == Kind::MoveBlock;
407  }
408 
409  void commit(RewriterBase &rewriter) override {
410  // The block was already moved. Just inform the listener.
411  if (auto *listener = rewriter.getListener()) {
412  // Note: `previousIt` cannot be passed because this is a delayed
413  // notification and iterators into past IR state cannot be represented.
414  listener->notifyBlockInserted(block, /*previous=*/region,
415  /*previousIt=*/{});
416  }
417  }
418 
419  void rollback() override {
420  // Move the block back to its original position.
421  Region::iterator before =
422  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
423  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
424  }
425 
426 private:
427  // The region in which this block was previously contained.
428  Region *region;
429 
430  // The original successor of this block before it was moved. "nullptr" if
431  // this block was the only block in the region.
432  Block *insertBeforeBlock;
433 };
434 
435 /// This structure contains the information pertaining to an argument that has
436 /// been converted.
437 struct ConvertedArgInfo {
438  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
439  Value castValue = nullptr)
440  : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
441 
442  /// The start index of in the new argument list that contains arguments that
443  /// replace the original.
444  unsigned newArgIdx;
445 
446  /// The number of arguments that replaced the original argument.
447  unsigned newArgSize;
448 
449  /// The cast value that was created to cast from the new arguments to the
450  /// old. This only used if 'newArgSize' > 1.
451  Value castValue;
452 };
453 
454 /// Block type conversion. This rewrite is partially reflected in the IR.
455 class BlockTypeConversionRewrite : public BlockRewrite {
456 public:
457  BlockTypeConversionRewrite(
458  ConversionPatternRewriterImpl &rewriterImpl, Block *block,
459  Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
460  const TypeConverter *converter)
461  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
462  origBlock(origBlock), argInfo(argInfo), converter(converter) {}
463 
464  static bool classof(const IRRewrite *rewrite) {
465  return rewrite->getKind() == Kind::BlockTypeConversion;
466  }
467 
468  /// Materialize any necessary conversions for converted arguments that have
469  /// live users, using the provided `findLiveUser` to search for a user that
470  /// survives the conversion process.
472  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
473 
474  void commit(RewriterBase &rewriter) override;
475 
476  void rollback() override;
477 
478 private:
479  /// The original block that was requested to have its signature converted.
480  Block *origBlock;
481 
482  /// The conversion information for each of the arguments. The information is
483  /// std::nullopt if the argument was dropped during conversion.
485 
486  /// The type converter used to convert the arguments.
487  const TypeConverter *converter;
488 };
489 
490 /// Replacing a block argument. This rewrite is not immediately reflected in the
491 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
492 /// until the rewrite is committed.
493 class ReplaceBlockArgRewrite : public BlockRewrite {
494 public:
495  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
496  Block *block, BlockArgument arg)
497  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
498 
499  static bool classof(const IRRewrite *rewrite) {
500  return rewrite->getKind() == Kind::ReplaceBlockArg;
501  }
502 
503  void commit(RewriterBase &rewriter) override;
504 
505  void rollback() override;
506 
507 private:
508  BlockArgument arg;
509 };
510 
511 /// An operation rewrite.
512 class OperationRewrite : public IRRewrite {
513 public:
514  /// Return the operation that this rewrite operates on.
515  Operation *getOperation() const { return op; }
516 
517  static bool classof(const IRRewrite *rewrite) {
518  return rewrite->getKind() >= Kind::MoveOperation &&
519  rewrite->getKind() <= Kind::UnresolvedMaterialization;
520  }
521 
522 protected:
523  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
524  Operation *op)
525  : IRRewrite(kind, rewriterImpl), op(op) {}
526 
527  // The operation that this rewrite operates on.
528  Operation *op;
529 };
530 
531 /// Moving of an operation. This rewrite is immediately reflected in the IR.
532 class MoveOperationRewrite : public OperationRewrite {
533 public:
534  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
535  Operation *op, Block *block, Operation *insertBeforeOp)
536  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
537  insertBeforeOp(insertBeforeOp) {}
538 
539  static bool classof(const IRRewrite *rewrite) {
540  return rewrite->getKind() == Kind::MoveOperation;
541  }
542 
543  void commit(RewriterBase &rewriter) override {
544  // The operation was already moved. Just inform the listener.
545  if (auto *listener = rewriter.getListener()) {
546  // Note: `previousIt` cannot be passed because this is a delayed
547  // notification and iterators into past IR state cannot be represented.
548  listener->notifyOperationInserted(
549  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
550  /*insertPt=*/{}));
551  }
552  }
553 
554  void rollback() override {
555  // Move the operation back to its original position.
556  Block::iterator before =
557  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
558  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
559  }
560 
561 private:
562  // The block in which this operation was previously contained.
563  Block *block;
564 
565  // The original successor of this operation before it was moved. "nullptr"
566  // if this operation was the only operation in the region.
567  Operation *insertBeforeOp;
568 };
569 
570 /// In-place modification of an op. This rewrite is immediately reflected in
571 /// the IR. The previous state of the operation is stored in this object.
572 class ModifyOperationRewrite : public OperationRewrite {
573 public:
574  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
575  Operation *op)
576  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
577  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
578  operands(op->operand_begin(), op->operand_end()),
579  successors(op->successor_begin(), op->successor_end()) {
580  if (OpaqueProperties prop = op->getPropertiesStorage()) {
581  // Make a copy of the properties.
582  propertiesStorage = operator new(op->getPropertiesStorageSize());
583  OpaqueProperties propCopy(propertiesStorage);
584  name.initOpProperties(propCopy, /*init=*/prop);
585  }
586  }
587 
588  static bool classof(const IRRewrite *rewrite) {
589  return rewrite->getKind() == Kind::ModifyOperation;
590  }
591 
592  ~ModifyOperationRewrite() override {
593  assert(!propertiesStorage &&
594  "rewrite was neither committed nor rolled back");
595  }
596 
597  void commit(RewriterBase &rewriter) override {
598  // Notify the listener that the operation was modified in-place.
599  if (auto *listener =
600  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
601  listener->notifyOperationModified(op);
602 
603  if (propertiesStorage) {
604  OpaqueProperties propCopy(propertiesStorage);
605  // Note: The operation may have been erased in the mean time, so
606  // OperationName must be stored in this object.
607  name.destroyOpProperties(propCopy);
608  operator delete(propertiesStorage);
609  propertiesStorage = nullptr;
610  }
611  }
612 
613  void rollback() override {
614  op->setLoc(loc);
615  op->setAttrs(attrs);
616  op->setOperands(operands);
617  for (const auto &it : llvm::enumerate(successors))
618  op->setSuccessor(it.value(), it.index());
619  if (propertiesStorage) {
620  OpaqueProperties propCopy(propertiesStorage);
621  op->copyProperties(propCopy);
622  name.destroyOpProperties(propCopy);
623  operator delete(propertiesStorage);
624  propertiesStorage = nullptr;
625  }
626  }
627 
628 private:
629  OperationName name;
630  LocationAttr loc;
631  DictionaryAttr attrs;
632  SmallVector<Value, 8> operands;
633  SmallVector<Block *, 2> successors;
634  void *propertiesStorage = nullptr;
635 };
636 
637 /// Replacing an operation. Erasing an operation is treated as a special case
638 /// with "null" replacements. This rewrite is not immediately reflected in the
639 /// IR. An internal IR mapping is updated, but values are not replaced and the
640 /// original op is not erased until the rewrite is committed.
641 class ReplaceOperationRewrite : public OperationRewrite {
642 public:
643  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
644  Operation *op, const TypeConverter *converter,
645  bool changedResults)
646  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
647  converter(converter), changedResults(changedResults) {}
648 
649  static bool classof(const IRRewrite *rewrite) {
650  return rewrite->getKind() == Kind::ReplaceOperation;
651  }
652 
653  void commit(RewriterBase &rewriter) override;
654 
655  void rollback() override;
656 
657  void cleanup(RewriterBase &rewriter) override;
658 
659  const TypeConverter *getConverter() const { return converter; }
660 
661  bool hasChangedResults() const { return changedResults; }
662 
663 private:
664  /// An optional type converter that can be used to materialize conversions
665  /// between the new and old values if necessary.
666  const TypeConverter *converter;
667 
668  /// A boolean flag that indicates whether result types have changed or not.
669  bool changedResults;
670 };
671 
672 class CreateOperationRewrite : public OperationRewrite {
673 public:
674  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
675  Operation *op)
676  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
677 
678  static bool classof(const IRRewrite *rewrite) {
679  return rewrite->getKind() == Kind::CreateOperation;
680  }
681 
682  void commit(RewriterBase &rewriter) override {
683  // The operation was already created and inserted. Just inform the listener.
684  if (auto *listener = rewriter.getListener())
685  listener->notifyOperationInserted(op, /*previous=*/{});
686  }
687 
688  void rollback() override;
689 };
690 
/// The kind of a materialization.
enum MaterializationKind {
  /// Materialization of a conversion for an illegal block argument type, to a
  /// legal one.
  Argument,

  /// Materialization of a conversion from an illegal type to a legal one.
  Target
};
701 
702 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
703 /// op. Unresolved materializations are erased at the end of the dialect
704 /// conversion.
705 class UnresolvedMaterializationRewrite : public OperationRewrite {
706 public:
707  UnresolvedMaterializationRewrite(
708  ConversionPatternRewriterImpl &rewriterImpl,
709  UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
710  MaterializationKind kind = MaterializationKind::Target,
711  Type origOutputType = nullptr)
712  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
713  converterAndKind(converter, kind), origOutputType(origOutputType) {}
714 
715  static bool classof(const IRRewrite *rewrite) {
716  return rewrite->getKind() == Kind::UnresolvedMaterialization;
717  }
718 
719  UnrealizedConversionCastOp getOperation() const {
720  return cast<UnrealizedConversionCastOp>(op);
721  }
722 
723  void rollback() override;
724 
725  void cleanup(RewriterBase &rewriter) override;
726 
727  /// Return the type converter of this materialization (which may be null).
728  const TypeConverter *getConverter() const {
729  return converterAndKind.getPointer();
730  }
731 
732  /// Return the kind of this materialization.
733  MaterializationKind getMaterializationKind() const {
734  return converterAndKind.getInt();
735  }
736 
737  /// Set the kind of this materialization.
738  void setMaterializationKind(MaterializationKind kind) {
739  converterAndKind.setInt(kind);
740  }
741 
742  /// Return the original illegal output type of the input values.
743  Type getOrigOutputType() const { return origOutputType; }
744 
745 private:
746  /// The corresponding type converter to use when resolving this
747  /// materialization, and the kind of this materialization.
748  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
749  converterAndKind;
750 
751  /// The original output type. This is only used for argument conversions.
752  Type origOutputType;
753 };
754 } // namespace
755 
756 /// Return "true" if there is an operation rewrite that matches the specified
757 /// rewrite type and operation among the given rewrites.
758 template <typename RewriteTy, typename R>
759 static bool hasRewrite(R &&rewrites, Operation *op) {
760  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
761  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
762  return rewriteTy && rewriteTy->getOperation() == op;
763  });
764 }
765 
766 /// Find the single rewrite object of the specified type and block among the
767 /// given rewrites. In debug mode, asserts that there is mo more than one such
768 /// object. Return "nullptr" if no object was found.
769 template <typename RewriteTy, typename R>
770 static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
771  RewriteTy *result = nullptr;
772  for (auto &rewrite : rewrites) {
773  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
774  if (rewriteTy && rewriteTy->getBlock() == block) {
775 #ifndef NDEBUG
776  assert(!result && "expected single matching rewrite");
777  result = rewriteTy;
778 #else
779  return rewriteTy;
780 #endif // NDEBUG
781  }
782  }
783  return result;
784 }
785 
786 //===----------------------------------------------------------------------===//
787 // ConversionPatternRewriterImpl
788 //===----------------------------------------------------------------------===//
789 namespace mlir {
790 namespace detail {
793  const ConversionConfig &config)
794  : context(ctx), config(config) {}
795 
796  //===--------------------------------------------------------------------===//
797  // State Management
798  //===--------------------------------------------------------------------===//
799 
800  /// Return the current state of the rewriter.
801  RewriterState getCurrentState();
802 
803  /// Apply all requested operation rewrites. This method is invoked when the
804  /// conversion process succeeds.
805  void applyRewrites();
806 
807  /// Reset the state of the rewriter to a previously saved point.
808  void resetState(RewriterState state);
809 
810  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
811  /// failure.
812  template <typename RewriteTy, typename... Args>
813  void appendRewrite(Args &&...args) {
814  rewrites.push_back(
815  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
816  }
817 
818  /// Undo the rewrites (motions, splits) one by one in reverse order until
819  /// "numRewritesToKeep" rewrites remains.
820  void undoRewrites(unsigned numRewritesToKeep = 0);
821 
822  /// Remap the given values to those with potentially different types. Returns
823  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
824  /// is the tag used when describing a value within a diagnostic, e.g.
825  /// "operand".
826  LogicalResult remapValues(StringRef valueDiagTag,
827  std::optional<Location> inputLoc,
828  PatternRewriter &rewriter, ValueRange values,
829  SmallVectorImpl<Value> &remapped);
830 
831  /// Return "true" if the given operation is ignored, and does not need to be
832  /// converted.
833  bool isOpIgnored(Operation *op) const;
834 
835  /// Return "true" if the given operation was replaced or erased.
836  bool wasOpReplaced(Operation *op) const;
837 
838  //===--------------------------------------------------------------------===//
839  // Type Conversion
840  //===--------------------------------------------------------------------===//
841 
842  /// Attempt to convert the signature of the given block, if successful a new
843  /// block is returned containing the new arguments. Returns `block` if it did
844  /// not require conversion.
845  FailureOr<Block *> convertBlockSignature(
846  ConversionPatternRewriter &rewriter, Block *block,
847  const TypeConverter *converter,
848  TypeConverter::SignatureConversion *conversion = nullptr);
849 
850  /// Convert the types of non-entry block arguments within the given region.
851  LogicalResult convertNonEntryRegionTypes(
852  ConversionPatternRewriter &rewriter, Region *region,
853  const TypeConverter &converter,
854  ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
855 
856  /// Apply a signature conversion on the given region, using `converter` for
857  /// materializations if not null.
858  Block *
859  applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
861  const TypeConverter *converter);
862 
863  /// Convert the types of block arguments within the given region.
865  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
866  const TypeConverter &converter,
867  TypeConverter::SignatureConversion *entryConversion);
868 
869  /// Apply the given signature conversion on the given block. The new block
870  /// containing the updated signature is returned. If no conversions were
871  /// necessary, e.g. if the block has no arguments, `block` is returned.
872  /// `converter` is used to generate any necessary cast operations that
873  /// translate between the origin argument types and those specified in the
874  /// signature conversion.
875  Block *applySignatureConversion(
876  ConversionPatternRewriter &rewriter, Block *block,
877  const TypeConverter *converter,
878  TypeConverter::SignatureConversion &signatureConversion);
879 
880  //===--------------------------------------------------------------------===//
881  // Materializations
882  //===--------------------------------------------------------------------===//
883  /// Build an unresolved materialization operation given an output type and set
884  /// of input operands.
885  Value buildUnresolvedMaterialization(MaterializationKind kind,
886  Block *insertBlock,
887  Block::iterator insertPt, Location loc,
888  ValueRange inputs, Type outputType,
889  Type origOutputType,
890  const TypeConverter *converter);
891 
892  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
893  ValueRange inputs,
894  Type origOutputType,
895  Type outputType,
896  const TypeConverter *converter);
897 
898  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
899  Type outputType,
900  const TypeConverter *converter);
901 
902  //===--------------------------------------------------------------------===//
903  // Rewriter Notification Hooks
904  //===--------------------------------------------------------------------===//
905 
906  /// Notifies that an op was inserted.
907  void notifyOperationInserted(Operation *op,
908  OpBuilder::InsertPoint previous) override;
909 
910  /// Notifies that an op is about to be replaced with the given values.
911  void notifyOpReplaced(Operation *op, ValueRange newValues);
912 
913  /// Notifies that a block is about to be erased.
914  void notifyBlockIsBeingErased(Block *block);
915 
916  /// Notifies that a block was inserted.
917  void notifyBlockInserted(Block *block, Region *previous,
918  Region::iterator previousIt) override;
919 
920  /// Notifies that a block is being inlined into another block.
921  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
922  Block::iterator before);
923 
924  /// Notifies that a pattern match failed for the given reason.
925  void
926  notifyMatchFailure(Location loc,
927  function_ref<void(Diagnostic &)> reasonCallback) override;
928 
929  //===--------------------------------------------------------------------===//
930  // IR Erasure
931  //===--------------------------------------------------------------------===//
932 
933  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
934  /// operation or block is erased multiple times. This rewriter assumes that
935  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
937  public:
939  : RewriterBase(context, /*listener=*/this) {}
940 
941  /// Erase the given op (unless it was already erased).
942  void eraseOp(Operation *op) override {
943  if (erased.contains(op))
944  return;
945  op->dropAllUses();
947  }
948 
949  /// Erase the given block (unless it was already erased).
950  void eraseBlock(Block *block) override {
951  if (erased.contains(block))
952  return;
953  assert(block->empty() && "expected empty block");
954  block->dropAllDefinedValueUses();
956  }
957 
958  void notifyOperationErased(Operation *op) override { erased.insert(op); }
959 
960  void notifyBlockErased(Block *block) override { erased.insert(block); }
961 
962  /// Pointers to all erased operations and blocks.
964  };
965 
966  //===--------------------------------------------------------------------===//
967  // State
968  //===--------------------------------------------------------------------===//
969 
970  /// MLIR context.
972 
973  // Mapping between replaced values that differ in type. This happens when
974  // replacing a value with one of a different type.
975  ConversionValueMapping mapping;
976 
977  /// Ordered list of block operations (creations, splits, motions).
979 
980  /// A set of operations that should no longer be considered for legalization.
981  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
982  /// tracked separately.
984 
985  /// A set of operations that were replaced/erased. Such ops are not erased
986  /// immediately but only when the dialect conversion succeeds. In the mean
987  /// time, they should no longer be considered for legalization and any attempt
988  /// to modify/access them is invalid rewriter API usage.
990 
991  /// The current type converter, or nullptr if no type converter is currently
992  /// active.
993  const TypeConverter *currentTypeConverter = nullptr;
994 
995  /// A mapping of regions to type converters that should be used when
996  /// converting the arguments of blocks within that region.
998 
999  /// Dialect conversion configuration.
1001 
1002 #ifndef NDEBUG
1003  /// A set of operations that have pending updates. This tracking isn't
1004  /// strictly necessary, and is thus only active during debug builds for extra
1005  /// verification.
1007 
1008  /// A logger used to emit diagnostics during the conversion process.
1009  llvm::ScopedPrinter logger{llvm::dbgs()};
1010 #endif
1011 };
1012 } // namespace detail
1013 } // namespace mlir
1014 
1015 const ConversionConfig &IRRewrite::getConfig() const {
1016  return rewriterImpl.config;
1017 }
1018 
1019 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1020  // Inform the listener about all IR modifications that have already taken
1021  // place: References to the original block have been replaced with the new
1022  // block.
1023  if (auto *listener =
1024  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1025  for (Operation *op : block->getUsers())
1026  listener->notifyOperationModified(op);
1027 
1028  // Process the remapping for each of the original arguments.
1029  for (auto [origArg, info] :
1030  llvm::zip_equal(origBlock->getArguments(), argInfo)) {
1031  // Handle the case of a 1->0 value mapping.
1032  if (!info) {
1033  if (Value newArg =
1034  rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
1035  rewriter.replaceAllUsesWith(origArg, newArg);
1036  continue;
1037  }
1038 
1039  // Otherwise this is a 1->1+ value mapping.
1040  Value castValue = info->castValue;
1041  assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
1042 
1043  // If the argument is still used, replace it with the generated cast.
1044  if (!origArg.use_empty()) {
1045  rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
1046  castValue, origArg.getType()));
1047  }
1048  }
1049 }
1050 
1051 void BlockTypeConversionRewrite::rollback() {
1052  block->replaceAllUsesWith(origBlock);
1053 }
1054 
1055 LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1056  function_ref<Operation *(Value)> findLiveUser) {
1057  // Process the remapping for each of the original arguments.
1058  for (auto it : llvm::enumerate(origBlock->getArguments())) {
1059  BlockArgument origArg = it.value();
1060  // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
1061  OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
1062  builder.setInsertionPointToStart(block);
1063 
1064  // If the type of this argument changed and the argument is still live, we
1065  // need to materialize a conversion.
1066  if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
1067  continue;
1068  Operation *liveUser = findLiveUser(origArg);
1069  if (!liveUser)
1070  continue;
1071 
1072  Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
1073  bool isDroppedArg = replacementValue == origArg;
1074  if (!isDroppedArg)
1075  builder.setInsertionPointAfterValue(replacementValue);
1076  Value newArg;
1077  if (converter) {
1078  newArg = converter->materializeSourceConversion(
1079  builder, origArg.getLoc(), origArg.getType(),
1080  isDroppedArg ? ValueRange() : ValueRange(replacementValue));
1081  assert((!newArg || newArg.getType() == origArg.getType()) &&
1082  "materialization hook did not provide a value of the expected "
1083  "type");
1084  }
1085  if (!newArg) {
1087  emitError(origArg.getLoc())
1088  << "failed to materialize conversion for block argument #"
1089  << it.index() << " that remained live after conversion, type was "
1090  << origArg.getType();
1091  if (!isDroppedArg)
1092  diag << ", with target type " << replacementValue.getType();
1093  diag.attachNote(liveUser->getLoc())
1094  << "see existing live user here: " << *liveUser;
1095  return failure();
1096  }
1097  rewriterImpl.mapping.map(origArg, newArg);
1098  }
1099  return success();
1100 }
1101 
1102 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1103  Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
1104  if (!repl)
1105  return;
1106 
1107  if (isa<BlockArgument>(repl)) {
1108  rewriter.replaceAllUsesWith(arg, repl);
1109  return;
1110  }
1111 
1112  // If the replacement value is an operation, we check to make sure that we
1113  // don't replace uses that are within the parent operation of the
1114  // replacement value.
1115  Operation *replOp = cast<OpResult>(repl).getOwner();
1116  Block *replBlock = replOp->getBlock();
1117  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1118  Operation *user = operand.getOwner();
1119  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1120  });
1121 }
1122 
1123 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
1124 
1125 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1126  auto *listener =
1127  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1128 
1129  // Compute replacement values.
1130  SmallVector<Value> replacements =
1131  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1132  return rewriterImpl.mapping.lookupOrNull(result, result.getType());
1133  });
1134 
1135  // Notify the listener that the operation is about to be replaced.
1136  if (listener)
1137  listener->notifyOperationReplaced(op, replacements);
1138 
1139  // Replace all uses with the new values.
1140  for (auto [result, newValue] :
1141  llvm::zip_equal(op->getResults(), replacements))
1142  if (newValue)
1143  rewriter.replaceAllUsesWith(result, newValue);
1144 
1145  // The original op will be erased, so remove it from the set of unlegalized
1146  // ops.
1147  if (getConfig().unlegalizedOps)
1148  getConfig().unlegalizedOps->erase(op);
1149 
1150  // Notify the listener that the operation (and its nested operations) was
1151  // erased.
1152  if (listener) {
1154  [&](Operation *op) { listener->notifyOperationErased(op); });
1155  }
1156 
1157  // Do not erase the operation yet. It may still be referenced in `mapping`.
1158  // Just unlink it for now and erase it during cleanup.
1159  op->getBlock()->getOperations().remove(op);
1160 }
1161 
1162 void ReplaceOperationRewrite::rollback() {
1163  for (auto result : op->getResults())
1164  rewriterImpl.mapping.erase(result);
1165 }
1166 
1167 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1168  rewriter.eraseOp(op);
1169 }
1170 
1171 void CreateOperationRewrite::rollback() {
1172  for (Region &region : op->getRegions()) {
1173  while (!region.getBlocks().empty())
1174  region.getBlocks().remove(region.getBlocks().begin());
1175  }
1176  op->dropAllUses();
1177  op->erase();
1178 }
1179 
1180 void UnresolvedMaterializationRewrite::rollback() {
1181  if (getMaterializationKind() == MaterializationKind::Target) {
1182  for (Value input : op->getOperands())
1183  rewriterImpl.mapping.erase(input);
1184  }
1185  op->erase();
1186 }
1187 
1188 void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
1189  rewriter.eraseOp(op);
1190 }
1191 
1193  // Commit all rewrites.
1194  IRRewriter rewriter(context, config.listener);
1195  for (auto &rewrite : rewrites)
1196  rewrite->commit(rewriter);
1197 
1198  // Clean up all rewrites.
1199  SingleEraseRewriter eraseRewriter(context);
1200  for (auto &rewrite : rewrites)
1201  rewrite->cleanup(eraseRewriter);
1202 }
1203 
1204 //===----------------------------------------------------------------------===//
1205 // State Management
1206 
1208  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1209 }
1210 
1212  // Undo any rewrites.
1213  undoRewrites(state.numRewrites);
1214 
1215  // Pop all of the recorded ignored operations that are no longer valid.
1216  while (ignoredOps.size() != state.numIgnoredOperations)
1217  ignoredOps.pop_back();
1218 
1219  while (replacedOps.size() != state.numReplacedOps)
1220  replacedOps.pop_back();
1221 }
1222 
1223 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1224  for (auto &rewrite :
1225  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1226  rewrite->rollback();
1227  rewrites.resize(numRewritesToKeep);
1228 }
1229 
1231  StringRef valueDiagTag, std::optional<Location> inputLoc,
1232  PatternRewriter &rewriter, ValueRange values,
1233  SmallVectorImpl<Value> &remapped) {
1234  remapped.reserve(llvm::size(values));
1235 
1236  SmallVector<Type, 1> legalTypes;
1237  for (const auto &it : llvm::enumerate(values)) {
1238  Value operand = it.value();
1239  Type origType = operand.getType();
1240 
1241  // If a converter was provided, get the desired legal types for this
1242  // operand.
1243  Type desiredType;
1244  if (currentTypeConverter) {
1245  // If there is no legal conversion, fail to match this pattern.
1246  legalTypes.clear();
1247  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1248  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1249  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1250  diag << "unable to convert type for " << valueDiagTag << " #"
1251  << it.index() << ", type was " << origType;
1252  });
1253  return failure();
1254  }
1255  // TODO: There currently isn't any mechanism to do 1->N type conversion
1256  // via the PatternRewriter replacement API, so for now we just ignore it.
1257  if (legalTypes.size() == 1)
1258  desiredType = legalTypes.front();
1259  } else {
1260  // TODO: What we should do here is just set `desiredType` to `origType`
1261  // and then handle the necessary type conversions after the conversion
1262  // process has finished. Unfortunately a lot of patterns currently rely on
1263  // receiving the new operands even if the types change, so we keep the
1264  // original behavior here for now until all of the patterns relying on
1265  // this get updated.
1266  }
1267  Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1268 
1269  // Handle the case where the conversion was 1->1 and the new operand type
1270  // isn't legal.
1271  Type newOperandType = newOperand.getType();
1272  if (currentTypeConverter && desiredType && newOperandType != desiredType) {
1273  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1275  operandLoc, newOperand, desiredType, currentTypeConverter);
1276  mapping.map(mapping.lookupOrDefault(newOperand), castValue);
1277  newOperand = castValue;
1278  }
1279  remapped.push_back(newOperand);
1280  }
1281  return success();
1282 }
1283 
1285  // Check to see if this operation is ignored or was replaced.
1286  return replacedOps.count(op) || ignoredOps.count(op);
1287 }
1288 
1290  // Check to see if this operation was replaced.
1291  return replacedOps.count(op);
1292 }
1293 
1294 //===----------------------------------------------------------------------===//
1295 // Type Conversion
1296 
1298  ConversionPatternRewriter &rewriter, Block *block,
1299  const TypeConverter *converter,
1300  TypeConverter::SignatureConversion *conversion) {
1301  if (conversion)
1302  return applySignatureConversion(rewriter, block, converter, *conversion);
1303 
1304  // If a converter wasn't provided, and the block wasn't already converted,
1305  // there is nothing we can do.
1306  if (!converter)
1307  return failure();
1308 
1309  // Try to convert the signature for the block with the provided converter.
1310  if (auto conversion = converter->convertBlockSignature(block))
1311  return applySignatureConversion(rewriter, block, converter, *conversion);
1312  return failure();
1313 }
1314 
1316  ConversionPatternRewriter &rewriter, Region *region,
1318  const TypeConverter *converter) {
1319  if (!region->empty())
1320  return *convertBlockSignature(rewriter, &region->front(), converter,
1321  &conversion);
1322  return nullptr;
1323 }
1324 
1326  ConversionPatternRewriter &rewriter, Region *region,
1327  const TypeConverter &converter,
1328  TypeConverter::SignatureConversion *entryConversion) {
1329  regionToConverter[region] = &converter;
1330  if (region->empty())
1331  return nullptr;
1332 
1333  if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
1334  return failure();
1335 
1337  rewriter, &region->front(), &converter, entryConversion);
1338  return newEntry;
1339 }
1340 
1342  ConversionPatternRewriter &rewriter, Region *region,
1343  const TypeConverter &converter,
1345  regionToConverter[region] = &converter;
1346  if (region->empty())
1347  return success();
1348 
1349  // Convert the arguments of each block within the region.
1350  int blockIdx = 0;
1351  assert((blockConversions.empty() ||
1352  blockConversions.size() == region->getBlocks().size() - 1) &&
1353  "expected either to provide no SignatureConversions at all or to "
1354  "provide a SignatureConversion for each non-entry block");
1355 
1356  for (Block &block :
1357  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1358  TypeConverter::SignatureConversion *blockConversion =
1359  blockConversions.empty()
1360  ? nullptr
1361  : const_cast<TypeConverter::SignatureConversion *>(
1362  &blockConversions[blockIdx++]);
1363 
1364  if (failed(convertBlockSignature(rewriter, &block, &converter,
1365  blockConversion)))
1366  return failure();
1367  }
1368  return success();
1369 }
1370 
1372  ConversionPatternRewriter &rewriter, Block *block,
1373  const TypeConverter *converter,
1374  TypeConverter::SignatureConversion &signatureConversion) {
1375  OpBuilder::InsertionGuard g(rewriter);
1376 
1377  // If no arguments are being changed or added, there is nothing to do.
1378  unsigned origArgCount = block->getNumArguments();
1379  auto convertedTypes = signatureConversion.getConvertedTypes();
1380  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1381  return block;
1382 
1383  // Compute the locations of all block arguments in the new block.
1384  SmallVector<Location> newLocs(convertedTypes.size(),
1385  rewriter.getUnknownLoc());
1386  for (unsigned i = 0; i < origArgCount; ++i) {
1387  auto inputMap = signatureConversion.getInputMapping(i);
1388  if (!inputMap || inputMap->replacementValue)
1389  continue;
1390  Location origLoc = block->getArgument(i).getLoc();
1391  for (unsigned j = 0; j < inputMap->size; ++j)
1392  newLocs[inputMap->inputNo + j] = origLoc;
1393  }
1394 
1395  // Insert a new block with the converted block argument types and move all ops
1396  // from the old block to the new block.
1397  Block *newBlock =
1398  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1399  convertedTypes, newLocs);
1400 
1401  // If a listener is attached to the dialect conversion, ops cannot be moved
1402  // to the destination block in bulk ("fast path"). This is because at the time
1403  // the notifications are sent, it is unknown which ops were moved. Instead,
1404  // ops should be moved one-by-one ("slow path"), so that a separate
1405  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1406  // a bit more efficient, so we try to do that when possible.
1407  bool fastPath = !config.listener;
1408  if (fastPath) {
1409  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1410  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1411  } else {
1412  while (!block->empty())
1413  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1414  }
1415 
1416  // Replace all uses of the old block with the new block.
1417  block->replaceAllUsesWith(newBlock);
1418 
1419  // Remap each of the original arguments as determined by the signature
1420  // conversion.
1422  argInfo.resize(origArgCount);
1423 
1424  for (unsigned i = 0; i != origArgCount; ++i) {
1425  auto inputMap = signatureConversion.getInputMapping(i);
1426  if (!inputMap)
1427  continue;
1428  BlockArgument origArg = block->getArgument(i);
1429 
1430  // If inputMap->replacementValue is not nullptr, then the argument is
1431  // dropped and a replacement value is provided to be the remappedValue.
1432  if (inputMap->replacementValue) {
1433  assert(inputMap->size == 0 &&
1434  "invalid to provide a replacement value when the argument isn't "
1435  "dropped");
1436  mapping.map(origArg, inputMap->replacementValue);
1437  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1438  continue;
1439  }
1440 
1441  // Otherwise, this is a 1->1+ mapping.
1442  auto replArgs =
1443  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1444  Value newArg;
1445 
1446  // If this is a 1->1 mapping and the types of new and replacement arguments
1447  // match (i.e. it's an identity map), then the argument is mapped to its
1448  // original type.
1449  // FIXME: We simply pass through the replacement argument if there wasn't a
1450  // converter, which isn't great as it allows implicit type conversions to
1451  // appear. We should properly restructure this code to handle cases where a
1452  // converter isn't provided and also to properly handle the case where an
1453  // argument materialization is actually a temporary source materialization
1454  // (e.g. in the case of 1->N).
1455  if (replArgs.size() == 1 &&
1456  (!converter || replArgs[0].getType() == origArg.getType())) {
1457  newArg = replArgs.front();
1458  } else {
1459  Type origOutputType = origArg.getType();
1460 
1461  // Legalize the argument output type.
1462  Type outputType = origOutputType;
1463  if (Type legalOutputType = converter->convertType(outputType))
1464  outputType = legalOutputType;
1465 
1467  newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
1468  converter);
1469  }
1470 
1471  mapping.map(origArg, newArg);
1472  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1473  argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
1474  }
1475 
1476  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1477  converter);
1478 
1479  // Erase the old block. (It is just unlinked for now and will be erased during
1480  // cleanup.)
1481  rewriter.eraseBlock(block);
1482 
1483  return newBlock;
1484 }
1485 
1486 //===----------------------------------------------------------------------===//
1487 // Materializations
1488 //===----------------------------------------------------------------------===//
1489 
1490 /// Build an unresolved materialization operation given an output type and set
1491 /// of input operands.
1493  MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1494  Location loc, ValueRange inputs, Type outputType, Type origOutputType,
1495  const TypeConverter *converter) {
1496  // Avoid materializing an unnecessary cast.
1497  if (inputs.size() == 1 && inputs.front().getType() == outputType)
1498  return inputs.front();
1499 
1500  // Create an unresolved materialization. We use a new OpBuilder to avoid
1501  // tracking the materialization like we do for other operations.
1502  OpBuilder builder(insertBlock, insertPt);
1503  auto convertOp =
1504  builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1505  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1506  origOutputType);
1507  return convertOp.getResult(0);
1508 }
1510  Block *block, Location loc, ValueRange inputs, Type origOutputType,
1511  Type outputType, const TypeConverter *converter) {
1513  block->begin(), loc, inputs, outputType,
1514  origOutputType, converter);
1515 }
1517  Location loc, Value input, Type outputType,
1518  const TypeConverter *converter) {
1519  Block *insertBlock = input.getParentBlock();
1520  Block::iterator insertPt = insertBlock->begin();
1521  if (OpResult inputRes = dyn_cast<OpResult>(input))
1522  insertPt = ++inputRes.getOwner()->getIterator();
1523 
1524  return buildUnresolvedMaterialization(MaterializationKind::Target,
1525  insertBlock, insertPt, loc, input,
1526  outputType, outputType, converter);
1527 }
1528 
1529 //===----------------------------------------------------------------------===//
1530 // Rewriter Notification Hooks
1531 
1533  Operation *op, OpBuilder::InsertPoint previous) {
1534  LLVM_DEBUG({
1535  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1536  << ")\n";
1537  });
1538  assert(!wasOpReplaced(op->getParentOp()) &&
1539  "attempting to insert into a block within a replaced/erased op");
1540 
1541  if (!previous.isSet()) {
1542  // This is a newly created op.
1543  appendRewrite<CreateOperationRewrite>(op);
1544  return;
1545  }
1546  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1547  ? nullptr
1548  : &*previous.getPoint();
1549  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1550 }
1551 
1553  ValueRange newValues) {
1554  assert(newValues.size() == op->getNumResults());
1555  assert(!ignoredOps.contains(op) && "operation was already replaced");
1556 
1557  // Track if any of the results changed, e.g. erased and replaced with null.
1558  bool resultChanged = false;
1559 
1560  // Create mappings for each of the new result values.
1561  for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
1562  if (!newValue) {
1563  resultChanged = true;
1564  continue;
1565  }
1566  // Remap, and check for any result type changes.
1567  mapping.map(result, newValue);
1568  resultChanged |= (newValue.getType() != result.getType());
1569  }
1570 
1571  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
1572  resultChanged);
1573 
1574  // Mark this operation and all nested ops as replaced.
1575  op->walk([&](Operation *op) { replacedOps.insert(op); });
1576 }
1577 
1579  Region *region = block->getParent();
1580  Block *origNextBlock = block->getNextNode();
1581  appendRewrite<EraseBlockRewrite>(block, region, origNextBlock);
1582 }
1583 
1585  Block *block, Region *previous, Region::iterator previousIt) {
1586  assert(!wasOpReplaced(block->getParentOp()) &&
1587  "attempting to insert into a region within a replaced/erased op");
1588  LLVM_DEBUG(
1589  {
1590  Operation *parent = block->getParentOp();
1591  if (parent) {
1592  logger.startLine() << "** Insert Block into : '" << parent->getName()
1593  << "'(" << parent << ")\n";
1594  } else {
1595  logger.startLine()
1596  << "** Insert Block into detached Region (nullptr parent op)'";
1597  }
1598  });
1599 
1600  if (!previous) {
1601  // This is a newly created block.
1602  appendRewrite<CreateBlockRewrite>(block);
1603  return;
1604  }
1605  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1606  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1607 }
1608 
1610  Block *block, Block *srcBlock, Block::iterator before) {
1611  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1612 }
1613 
1615  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1616  LLVM_DEBUG({
1618  reasonCallback(diag);
1619  logger.startLine() << "** Failure : " << diag.str() << "\n";
1620  if (config.notifyCallback)
1622  });
1623 }
1624 
1625 //===----------------------------------------------------------------------===//
1626 // ConversionPatternRewriter
1627 //===----------------------------------------------------------------------===//
1628 
1629 ConversionPatternRewriter::ConversionPatternRewriter(
1630  MLIRContext *ctx, const ConversionConfig &config)
1631  : PatternRewriter(ctx),
1632  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1633  setListener(impl.get());
1634 }
1635 
1637 
1639  assert(op && newOp && "expected non-null op");
1640  replaceOp(op, newOp->getResults());
1641 }
1642 
1644  assert(op->getNumResults() == newValues.size() &&
1645  "incorrect # of replacement values");
1646  LLVM_DEBUG({
1647  impl->logger.startLine()
1648  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1649  });
1650  impl->notifyOpReplaced(op, newValues);
1651 }
1652 
1654  LLVM_DEBUG({
1655  impl->logger.startLine()
1656  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1657  });
1658  SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1659  impl->notifyOpReplaced(op, nullRepls);
1660 }
1661 
1663  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1664  "attempting to erase a block within a replaced/erased op");
1665 
1666  // Mark all ops for erasure.
1667  for (Operation &op : *block)
1668  eraseOp(&op);
1669 
1670  // Unlink the block from its parent region. The block is kept in the rewrite
1671  // object and will be actually destroyed when rewrites are applied. This
1672  // allows us to keep the operations in the block live and undo the removal by
1673  // re-inserting the block.
1674  impl->notifyBlockIsBeingErased(block);
1675  block->getParent()->getBlocks().remove(block);
1676 }
1677 
1679  Region *region, TypeConverter::SignatureConversion &conversion,
1680  const TypeConverter *converter) {
1681  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1682  "attempting to apply a signature conversion to a block within a "
1683  "replaced/erased op");
1684  return impl->applySignatureConversion(*this, region, conversion, converter);
1685 }
1686 
1688  Region *region, const TypeConverter &converter,
1689  TypeConverter::SignatureConversion *entryConversion) {
1690  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1691  "attempting to apply a signature conversion to a block within a "
1692  "replaced/erased op");
1693  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1694 }
1695 
1697  Region *region, const TypeConverter &converter,
1699  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1700  "attempting to apply a signature conversion to a block within a "
1701  "replaced/erased op");
1702  return impl->convertNonEntryRegionTypes(*this, region, converter,
1703  blockConversions);
1704 }
1705 
1707  Value to) {
1708  LLVM_DEBUG({
1709  Operation *parentOp = from.getOwner()->getParentOp();
1710  impl->logger.startLine() << "** Replace Argument : '" << from
1711  << "'(in region of '" << parentOp->getName()
1712  << "'(" << from.getOwner()->getParentOp() << ")\n";
1713  });
1714  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
1715  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1716 }
1717 
1719  SmallVector<Value> remappedValues;
1720  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1721  remappedValues)))
1722  return nullptr;
1723  return remappedValues.front();
1724 }
1725 
1728  SmallVectorImpl<Value> &results) {
1729  if (keys.empty())
1730  return success();
1731  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1732  results);
1733 }
1734 
1736  Block::iterator before,
1737  ValueRange argValues) {
1738 #ifndef NDEBUG
1739  assert(argValues.size() == source->getNumArguments() &&
1740  "incorrect # of argument replacement values");
1741  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1742  "attempting to inline a block from a replaced/erased op");
1743  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1744  "attempting to inline a block into a replaced/erased op");
1745  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1746  // The source block will be deleted, so it should not have any users (i.e.,
1747  // there should be no predecessors).
1748  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1749  "expected 'source' to have no predecessors");
1750 #endif // NDEBUG
1751 
1752  // If a listener is attached to the dialect conversion, ops cannot be moved
1753  // to the destination block in bulk ("fast path"). This is because at the time
1754  // the notifications are sent, it is unknown which ops were moved. Instead,
1755  // ops should be moved one-by-one ("slow path"), so that a separate
1756  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1757  // a bit more efficient, so we try to do that when possible.
1758  bool fastPath = !impl->config.listener;
1759 
1760  if (fastPath)
1761  impl->notifyBlockBeingInlined(dest, source, before);
1762 
1763  // Replace all uses of block arguments.
1764  for (auto it : llvm::zip(source->getArguments(), argValues))
1765  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1766 
1767  if (fastPath) {
1768  // Move all ops at once.
1769  dest->getOperations().splice(before, source->getOperations());
1770  } else {
1771  // Move op by op.
1772  while (!source->empty())
1773  moveOpBefore(&source->front(), dest, before);
1774  }
1775 
1776  // Erase the source block.
1777  eraseBlock(source);
1778 }
1779 
1781  assert(!impl->wasOpReplaced(op) &&
1782  "attempting to modify a replaced/erased op");
1783 #ifndef NDEBUG
1784  impl->pendingRootUpdates.insert(op);
1785 #endif
1786  impl->appendRewrite<ModifyOperationRewrite>(op);
1787 }
1788 
1790  assert(!impl->wasOpReplaced(op) &&
1791  "attempting to modify a replaced/erased op");
1793  // There is nothing to do here, we only need to track the operation at the
1794  // start of the update.
1795 #ifndef NDEBUG
1796  assert(impl->pendingRootUpdates.erase(op) &&
1797  "operation did not have a pending in-place update");
1798 #endif
1799 }
1800 
1802 #ifndef NDEBUG
1803  assert(impl->pendingRootUpdates.erase(op) &&
1804  "operation did not have a pending in-place update");
1805 #endif
1806  // Erase the last update for this operation.
1807  auto it = llvm::find_if(
1808  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1809  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1810  return modifyRewrite && modifyRewrite->getOperation() == op;
1811  });
1812  assert(it != impl->rewrites.rend() && "no root update started on op");
1813  (*it)->rollback();
1814  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1815  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1816 }
1817 
1819  return *impl;
1820 }
1821 
1822 //===----------------------------------------------------------------------===//
1823 // ConversionPattern
1824 //===----------------------------------------------------------------------===//
1825 
1828  PatternRewriter &rewriter) const {
1829  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1830  auto &rewriterImpl = dialectRewriter.getImpl();
1831 
1832  // Track the current conversion pattern type converter in the rewriter.
1833  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1834  getTypeConverter());
1835 
1836  // Remap the operands of the operation.
1837  SmallVector<Value, 4> operands;
1838  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1839  op->getOperands(), operands))) {
1840  return failure();
1841  }
1842  return matchAndRewrite(op, operands, dialectRewriter);
1843 }
1844 
1845 //===----------------------------------------------------------------------===//
1846 // OperationLegalizer
1847 //===----------------------------------------------------------------------===//
1848 
1849 namespace {
1850 /// A set of rewrite patterns that can be used to legalize a given operation.
1851 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1852 
1853 /// This class defines a recursive operation legalizer.
1854 class OperationLegalizer {
1855 public:
1856  using LegalizationAction = ConversionTarget::LegalizationAction;
1857 
1858  OperationLegalizer(const ConversionTarget &targetInfo,
1859  const FrozenRewritePatternSet &patterns,
1860  const ConversionConfig &config);
1861 
1862  /// Returns true if the given operation is known to be illegal on the target.
1863  bool isIllegal(Operation *op) const;
1864 
1865  /// Attempt to legalize the given operation. Returns success if the operation
1866  /// was legalized, failure otherwise.
1867  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1868 
1869  /// Returns the conversion target in use by the legalizer.
1870  const ConversionTarget &getTarget() { return target; }
1871 
1872 private:
1873  /// Attempt to legalize the given operation by folding it.
1874  LogicalResult legalizeWithFold(Operation *op,
1875  ConversionPatternRewriter &rewriter);
1876 
1877  /// Attempt to legalize the given operation by applying a pattern. Returns
1878  /// success if the operation was legalized, failure otherwise.
1879  LogicalResult legalizeWithPattern(Operation *op,
1880  ConversionPatternRewriter &rewriter);
1881 
1882  /// Return true if the given pattern may be applied to the given operation,
1883  /// false otherwise.
1884  bool canApplyPattern(Operation *op, const Pattern &pattern,
1885  ConversionPatternRewriter &rewriter);
1886 
1887  /// Legalize the resultant IR after successfully applying the given pattern.
1888  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1889  ConversionPatternRewriter &rewriter,
1890  RewriterState &curState);
1891 
1892  /// Legalizes the actions registered during the execution of a pattern.
1894  legalizePatternBlockRewrites(Operation *op,
1895  ConversionPatternRewriter &rewriter,
1897  RewriterState &state, RewriterState &newState);
1898  LogicalResult legalizePatternCreatedOperations(
1900  RewriterState &state, RewriterState &newState);
1901  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1903  RewriterState &state,
1904  RewriterState &newState);
1905 
1906  //===--------------------------------------------------------------------===//
1907  // Cost Model
1908  //===--------------------------------------------------------------------===//
1909 
1910  /// Build an optimistic legalization graph given the provided patterns. This
1911  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1912  /// patterns for operations that are not directly legal, but may be
1913  /// transitively legal for the current target given the provided patterns.
1914  void buildLegalizationGraph(
1915  LegalizationPatterns &anyOpLegalizerPatterns,
1917 
1918  /// Compute the benefit of each node within the computed legalization graph.
1919  /// This orders the patterns within 'legalizerPatterns' based upon two
1920  /// criteria:
1921  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1922  /// represent the more direct mapping to the target.
1923  /// 2) When comparing patterns with the same legalization depth, prefer the
1924  /// pattern with the highest PatternBenefit. This allows for users to
1925  /// prefer specific legalizations over others.
1926  void computeLegalizationGraphBenefit(
1927  LegalizationPatterns &anyOpLegalizerPatterns,
1929 
1930  /// Compute the legalization depth when legalizing an operation of the given
1931  /// type.
1932  unsigned computeOpLegalizationDepth(
1933  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1935 
1936  /// Apply the conversion cost model to the given set of patterns, and return
1937  /// the smallest legalization depth of any of the patterns. See
1938  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1939  unsigned applyCostModelToPatterns(
1940  LegalizationPatterns &patterns,
1941  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1943 
1944  /// The current set of patterns that have been applied.
1945  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1946 
1947  /// The legalization information provided by the target.
1948  const ConversionTarget &target;
1949 
1950  /// The pattern applicator to use for conversions.
1951  PatternApplicator applicator;
1952 
1953  /// Dialect conversion configuration.
1954  const ConversionConfig &config;
1955 };
1956 } // namespace
1957 
1958 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1959  const FrozenRewritePatternSet &patterns,
1960  const ConversionConfig &config)
1961  : target(targetInfo), applicator(patterns), config(config) {
1962  // The set of patterns that can be applied to illegal operations to transform
1963  // them into legal ones.
1965  LegalizationPatterns anyOpLegalizerPatterns;
1966 
1967  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1968  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1969 }
1970 
1971 bool OperationLegalizer::isIllegal(Operation *op) const {
1972  return target.isIllegal(op);
1973 }
1974 
1976 OperationLegalizer::legalize(Operation *op,
1977  ConversionPatternRewriter &rewriter) {
1978 #ifndef NDEBUG
1979  const char *logLineComment =
1980  "//===-------------------------------------------===//\n";
1981 
1982  auto &logger = rewriter.getImpl().logger;
1983 #endif
1984  LLVM_DEBUG({
1985  logger.getOStream() << "\n";
1986  logger.startLine() << logLineComment;
1987  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1988  << op << ") {\n";
1989  logger.indent();
1990 
1991  // If the operation has no regions, just print it here.
1992  if (op->getNumRegions() == 0) {
1993  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1994  logger.getOStream() << "\n\n";
1995  }
1996  });
1997 
1998  // Check if this operation is legal on the target.
1999  if (auto legalityInfo = target.isLegal(op)) {
2000  LLVM_DEBUG({
2001  logSuccess(
2002  logger, "operation marked legal by the target{0}",
2003  legalityInfo->isRecursivelyLegal
2004  ? "; NOTE: operation is recursively legal; skipping internals"
2005  : "");
2006  logger.startLine() << logLineComment;
2007  });
2008 
2009  // If this operation is recursively legal, mark its children as ignored so
2010  // that we don't consider them for legalization.
2011  if (legalityInfo->isRecursivelyLegal) {
2012  op->walk([&](Operation *nested) {
2013  if (op != nested)
2014  rewriter.getImpl().ignoredOps.insert(nested);
2015  });
2016  }
2017 
2018  return success();
2019  }
2020 
2021  // Check to see if the operation is ignored and doesn't need to be converted.
2022  if (rewriter.getImpl().isOpIgnored(op)) {
2023  LLVM_DEBUG({
2024  logSuccess(logger, "operation marked 'ignored' during conversion");
2025  logger.startLine() << logLineComment;
2026  });
2027  return success();
2028  }
2029 
2030  // If the operation isn't legal, try to fold it in-place.
2031  // TODO: Should we always try to do this, even if the op is
2032  // already legal?
2033  if (succeeded(legalizeWithFold(op, rewriter))) {
2034  LLVM_DEBUG({
2035  logSuccess(logger, "operation was folded");
2036  logger.startLine() << logLineComment;
2037  });
2038  return success();
2039  }
2040 
2041  // Otherwise, we need to apply a legalization pattern to this operation.
2042  if (succeeded(legalizeWithPattern(op, rewriter))) {
2043  LLVM_DEBUG({
2044  logSuccess(logger, "");
2045  logger.startLine() << logLineComment;
2046  });
2047  return success();
2048  }
2049 
2050  LLVM_DEBUG({
2051  logFailure(logger, "no matched legalization pattern");
2052  logger.startLine() << logLineComment;
2053  });
2054  return failure();
2055 }
2056 
2058 OperationLegalizer::legalizeWithFold(Operation *op,
2059  ConversionPatternRewriter &rewriter) {
2060  auto &rewriterImpl = rewriter.getImpl();
2061  RewriterState curState = rewriterImpl.getCurrentState();
2062 
2063  LLVM_DEBUG({
2064  rewriterImpl.logger.startLine() << "* Fold {\n";
2065  rewriterImpl.logger.indent();
2066  });
2067 
2068  // Try to fold the operation.
2069  SmallVector<Value, 2> replacementValues;
2070  rewriter.setInsertionPoint(op);
2071  if (failed(rewriter.tryFold(op, replacementValues))) {
2072  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2073  return failure();
2074  }
2075  // An empty list of replacement values indicates that the fold was in-place.
2076  // As the operation changed, a new legalization needs to be attempted.
2077  if (replacementValues.empty())
2078  return legalize(op, rewriter);
2079 
2080  // Insert a replacement for 'op' with the folded replacement values.
2081  rewriter.replaceOp(op, replacementValues);
2082 
2083  // Recursively legalize any new constant operations.
2084  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2085  i != e; ++i) {
2086  auto *createOp =
2087  dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
2088  if (!createOp)
2089  continue;
2090  if (failed(legalize(createOp->getOperation(), rewriter))) {
2091  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2092  "failed to legalize generated constant '{0}'",
2093  createOp->getOperation()->getName()));
2094  rewriterImpl.resetState(curState);
2095  return failure();
2096  }
2097  }
2098 
2099  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2100  return success();
2101 }
2102 
2104 OperationLegalizer::legalizeWithPattern(Operation *op,
2105  ConversionPatternRewriter &rewriter) {
2106  auto &rewriterImpl = rewriter.getImpl();
2107 
2108  // Functor that returns if the given pattern may be applied.
2109  auto canApply = [&](const Pattern &pattern) {
2110  bool canApply = canApplyPattern(op, pattern, rewriter);
2111  if (canApply && config.listener)
2112  config.listener->notifyPatternBegin(pattern, op);
2113  return canApply;
2114  };
2115 
2116  // Functor that cleans up the rewriter state after a pattern failed to match.
2117  RewriterState curState = rewriterImpl.getCurrentState();
2118  auto onFailure = [&](const Pattern &pattern) {
2119  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2120  LLVM_DEBUG({
2121  logFailure(rewriterImpl.logger, "pattern failed to match");
2122  if (rewriterImpl.config.notifyCallback) {
2124  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2125  << "\" on op:\n"
2126  << *op;
2127  rewriterImpl.config.notifyCallback(diag);
2128  }
2129  });
2130  if (config.listener)
2131  config.listener->notifyPatternEnd(pattern, failure());
2132  rewriterImpl.resetState(curState);
2133  appliedPatterns.erase(&pattern);
2134  };
2135 
2136  // Functor that performs additional legalization when a pattern is
2137  // successfully applied.
2138  auto onSuccess = [&](const Pattern &pattern) {
2139  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2140  auto result = legalizePatternResult(op, pattern, rewriter, curState);
2141  appliedPatterns.erase(&pattern);
2142  if (failed(result))
2143  rewriterImpl.resetState(curState);
2144  if (config.listener)
2145  config.listener->notifyPatternEnd(pattern, result);
2146  return result;
2147  };
2148 
2149  // Try to match and rewrite a pattern on this operation.
2150  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2151  onSuccess);
2152 }
2153 
2154 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2155  ConversionPatternRewriter &rewriter) {
2156  LLVM_DEBUG({
2157  auto &os = rewriter.getImpl().logger;
2158  os.getOStream() << "\n";
2159  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2160  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2161  os.getOStream() << ")' {\n";
2162  os.indent();
2163  });
2164 
2165  // Ensure that we don't cycle by not allowing the same pattern to be
2166  // applied twice in the same recursion stack if it is not known to be safe.
2167  if (!pattern.hasBoundedRewriteRecursion() &&
2168  !appliedPatterns.insert(&pattern).second) {
2169  LLVM_DEBUG(
2170  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2171  return false;
2172  }
2173  return true;
2174 }
2175 
2177 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2178  ConversionPatternRewriter &rewriter,
2179  RewriterState &curState) {
2180  auto &impl = rewriter.getImpl();
2181 
2182 #ifndef NDEBUG
2183  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2184  // Check that the root was either replaced or updated in place.
2185  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2186  auto replacedRoot = [&] {
2187  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2188  };
2189  auto updatedRootInPlace = [&] {
2190  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2191  };
2192  assert((replacedRoot() || updatedRootInPlace()) &&
2193  "expected pattern to replace the root operation");
2194 #endif // NDEBUG
2195 
2196  // Legalize each of the actions registered during application.
2197  RewriterState newState = impl.getCurrentState();
2198  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2199  newState)) ||
2200  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2201  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2202  newState))) {
2203  return failure();
2204  }
2205 
2206  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2207  return success();
2208 }
2209 
2210 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2211  Operation *op, ConversionPatternRewriter &rewriter,
2212  ConversionPatternRewriterImpl &impl, RewriterState &state,
2213  RewriterState &newState) {
2214  SmallPtrSet<Operation *, 16> operationsToIgnore;
2215 
2216  // If the pattern moved or created any blocks, make sure the types of block
2217  // arguments get legalized.
2218  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2219  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2220  if (!rewrite)
2221  continue;
2222  Block *block = rewrite->getBlock();
2223  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2224  ReplaceBlockArgRewrite>(rewrite))
2225  continue;
2226  // Only check blocks outside of the current operation.
2227  Operation *parentOp = block->getParentOp();
2228  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2229  continue;
2230 
2231  // If the region of the block has a type converter, try to convert the block
2232  // directly.
2233  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2234  if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
2235  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2236  "block"));
2237  return failure();
2238  }
2239  continue;
2240  }
2241 
2242  // Otherwise, check that this operation isn't one generated by this pattern.
2243  // This is because we will attempt to legalize the parent operation, and
2244  // blocks in regions created by this pattern will already be legalized later
2245  // on. If we haven't built the set yet, build it now.
2246  if (operationsToIgnore.empty()) {
2247  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2248  ++i) {
2249  auto *createOp =
2250  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2251  if (!createOp)
2252  continue;
2253  operationsToIgnore.insert(createOp->getOperation());
2254  }
2255  }
2256 
2257  // If this operation should be considered for re-legalization, try it.
2258  if (operationsToIgnore.insert(parentOp).second &&
2259  failed(legalize(parentOp, rewriter))) {
2260  LLVM_DEBUG(logFailure(impl.logger,
2261  "operation '{0}'({1}) became illegal after rewrite",
2262  parentOp->getName(), parentOp));
2263  return failure();
2264  }
2265  }
2266  return success();
2267 }
2268 
2269 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2271  RewriterState &state, RewriterState &newState) {
2272  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2273  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2274  if (!createOp)
2275  continue;
2276  Operation *op = createOp->getOperation();
2277  if (failed(legalize(op, rewriter))) {
2278  LLVM_DEBUG(logFailure(impl.logger,
2279  "failed to legalize generated operation '{0}'({1})",
2280  op->getName(), op));
2281  return failure();
2282  }
2283  }
2284  return success();
2285 }
2286 
2287 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2289  RewriterState &state, RewriterState &newState) {
2290  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2291  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2292  if (!rewrite)
2293  continue;
2294  Operation *op = rewrite->getOperation();
2295  if (failed(legalize(op, rewriter))) {
2296  LLVM_DEBUG(logFailure(
2297  impl.logger, "failed to legalize operation updated in-place '{0}'",
2298  op->getName()));
2299  return failure();
2300  }
2301  }
2302  return success();
2303 }
2304 
2305 //===----------------------------------------------------------------------===//
2306 // Cost Model
2307 
2308 void OperationLegalizer::buildLegalizationGraph(
2309  LegalizationPatterns &anyOpLegalizerPatterns,
2310  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2311  // A mapping between an operation and a set of operations that can be used to
2312  // generate it.
2314  // A mapping between an operation and any currently invalid patterns it has.
2316  // A worklist of patterns to consider for legality.
2317  SetVector<const Pattern *> patternWorklist;
2318 
2319  // Build the mapping from operations to the parent ops that may generate them.
2320  applicator.walkAllPatterns([&](const Pattern &pattern) {
2321  std::optional<OperationName> root = pattern.getRootKind();
2322 
2323  // If the pattern has no specific root, we can't analyze the relationship
2324  // between the root op and generated operations. Given that, add all such
2325  // patterns to the legalization set.
2326  if (!root) {
2327  anyOpLegalizerPatterns.push_back(&pattern);
2328  return;
2329  }
2330 
2331  // Skip operations that are always known to be legal.
2332  if (target.getOpAction(*root) == LegalizationAction::Legal)
2333  return;
2334 
2335  // Add this pattern to the invalid set for the root op and record this root
2336  // as a parent for any generated operations.
2337  invalidPatterns[*root].insert(&pattern);
2338  for (auto op : pattern.getGeneratedOps())
2339  parentOps[op].insert(*root);
2340 
2341  // Add this pattern to the worklist.
2342  patternWorklist.insert(&pattern);
2343  });
2344 
2345  // If there are any patterns that don't have a specific root kind, we can't
2346  // make direct assumptions about what operations will never be legalized.
2347  // Note: Technically we could, but it would require an analysis that may
2348  // recurse into itself. It would be better to perform this kind of filtering
2349  // at a higher level than here anyways.
2350  if (!anyOpLegalizerPatterns.empty()) {
2351  for (const Pattern *pattern : patternWorklist)
2352  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2353  return;
2354  }
2355 
2356  while (!patternWorklist.empty()) {
2357  auto *pattern = patternWorklist.pop_back_val();
2358 
2359  // Check to see if any of the generated operations are invalid.
2360  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2361  std::optional<LegalizationAction> action = target.getOpAction(op);
2362  return !legalizerPatterns.count(op) &&
2363  (!action || action == LegalizationAction::Illegal);
2364  }))
2365  continue;
2366 
2367  // Otherwise, if all of the generated operation are valid, this op is now
2368  // legal so add all of the child patterns to the worklist.
2369  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2370  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2371 
2372  // Add any invalid patterns of the parent operations to see if they have now
2373  // become legal.
2374  for (auto op : parentOps[*pattern->getRootKind()])
2375  patternWorklist.set_union(invalidPatterns[op]);
2376  }
2377 }
2378 
2379 void OperationLegalizer::computeLegalizationGraphBenefit(
2380  LegalizationPatterns &anyOpLegalizerPatterns,
2381  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2382  // The smallest pattern depth, when legalizing an operation.
2383  DenseMap<OperationName, unsigned> minOpPatternDepth;
2384 
2385  // For each operation that is transitively legal, compute a cost for it.
2386  for (auto &opIt : legalizerPatterns)
2387  if (!minOpPatternDepth.count(opIt.first))
2388  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2389  legalizerPatterns);
2390 
2391  // Apply the cost model to the patterns that can match any operation. Those
2392  // with a specific operation type are already resolved when computing the op
2393  // legalization depth.
2394  if (!anyOpLegalizerPatterns.empty())
2395  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2396  legalizerPatterns);
2397 
2398  // Apply a cost model to the pattern applicator. We order patterns first by
2399  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2400  // decreasing benefit.
2401  applicator.applyCostModel([&](const Pattern &pattern) {
2402  ArrayRef<const Pattern *> orderedPatternList;
2403  if (std::optional<OperationName> rootName = pattern.getRootKind())
2404  orderedPatternList = legalizerPatterns[*rootName];
2405  else
2406  orderedPatternList = anyOpLegalizerPatterns;
2407 
2408  // If the pattern is not found, then it was removed and cannot be matched.
2409  auto *it = llvm::find(orderedPatternList, &pattern);
2410  if (it == orderedPatternList.end())
2412 
2413  // Patterns found earlier in the list have higher benefit.
2414  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2415  });
2416 }
2417 
2418 unsigned OperationLegalizer::computeOpLegalizationDepth(
2419  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2420  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2421  // Check for existing depth.
2422  auto depthIt = minOpPatternDepth.find(op);
2423  if (depthIt != minOpPatternDepth.end())
2424  return depthIt->second;
2425 
2426  // If a mapping for this operation does not exist, then this operation
2427  // is always legal. Return 0 as the depth for a directly legal operation.
2428  auto opPatternsIt = legalizerPatterns.find(op);
2429  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2430  return 0u;
2431 
2432  // Record this initial depth in case we encounter this op again when
2433  // recursively computing the depth.
2434  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2435 
2436  // Apply the cost model to the operation patterns, and update the minimum
2437  // depth.
2438  unsigned minDepth = applyCostModelToPatterns(
2439  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2440  minOpPatternDepth[op] = minDepth;
2441  return minDepth;
2442 }
2443 
2444 unsigned OperationLegalizer::applyCostModelToPatterns(
2445  LegalizationPatterns &patterns,
2446  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2447  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2448  unsigned minDepth = std::numeric_limits<unsigned>::max();
2449 
2450  // Compute the depth for each pattern within the set.
2452  patternsByDepth.reserve(patterns.size());
2453  for (const Pattern *pattern : patterns) {
2454  unsigned depth = 1;
2455  for (auto generatedOp : pattern->getGeneratedOps()) {
2456  unsigned generatedOpDepth = computeOpLegalizationDepth(
2457  generatedOp, minOpPatternDepth, legalizerPatterns);
2458  depth = std::max(depth, generatedOpDepth + 1);
2459  }
2460  patternsByDepth.emplace_back(pattern, depth);
2461 
2462  // Update the minimum depth of the pattern list.
2463  minDepth = std::min(minDepth, depth);
2464  }
2465 
2466  // If the operation only has one legalization pattern, there is no need to
2467  // sort them.
2468  if (patternsByDepth.size() == 1)
2469  return minDepth;
2470 
2471  // Sort the patterns by those likely to be the most beneficial.
2472  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2473  [](const std::pair<const Pattern *, unsigned> &lhs,
2474  const std::pair<const Pattern *, unsigned> &rhs) {
2475  // First sort by the smaller pattern legalization
2476  // depth.
2477  if (lhs.second != rhs.second)
2478  return lhs.second < rhs.second;
2479 
2480  // Then sort by the larger pattern benefit.
2481  auto lhsBenefit = lhs.first->getBenefit();
2482  auto rhsBenefit = rhs.first->getBenefit();
2483  return lhsBenefit > rhsBenefit;
2484  });
2485 
2486  // Update the legalization pattern to use the new sorted list.
2487  patterns.clear();
2488  for (auto &patternIt : patternsByDepth)
2489  patterns.push_back(patternIt.first);
2490  return minDepth;
2491 }
2492 
2493 //===----------------------------------------------------------------------===//
2494 // OperationConverter
2495 //===----------------------------------------------------------------------===//
namespace {
/// Selects how the operation converter reacts to operations that fail to
/// legalize and whether rewrites are applied.
enum OpConversionMode {
  /// Failed conversions are tolerated, so illegal operations may remain in
  /// the IR alongside converted ones.
  Partial,

  /// Every operation has to be legal for the given target; otherwise the
  /// conversion fails.
  Full,

  /// Operations are only analyzed for legality. On success no rewrite is
  /// actually applied to them.
  Analysis,
};
} // namespace
2511 
2512 namespace mlir {
2513 // This class converts operations to a given conversion target via a set of
2514 // rewrite patterns. The conversion behaves differently depending on the
2515 // conversion mode.
      // NOTE(review): this listing jumps from line 2515 to 2517, so the line that
      // opens the class definition is not visible here.
      //
      // The constructor stores its own copy of `config` and hands the legalizer
      // `this->config` (the member copy), not the constructor argument. `config`
      // is declared before `opLegalizer` below, so it is initialized first.
2517  explicit OperationConverter(const ConversionTarget &target,
2518  const FrozenRewritePatternSet &patterns,
2519  const ConversionConfig &config,
2520  OpConversionMode mode)
2521  : config(config), opLegalizer(target, patterns, this->config),
2522  mode(mode) {}
2523 
2524  /// Converts the given operations to the conversion target.
      // NOTE(review): listing line 2525 (the declaration this comment documents)
      // is missing from this extract.
2526 
2527 private:
2528  /// Converts an operation with the given rewriter.
      /// What happens when legalization fails depends on `mode`.
2529  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2530 
2531  /// This method is called after the conversion process to legalize any
2532  /// remaining artifacts and complete the conversion.
2533  LogicalResult finalize(ConversionPatternRewriter &rewriter);
2534 
2535  /// Legalize the types of converted block arguments.
      // NOTE(review): listing line 2536 (presumably the return type) is missing.
2537  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2538  ConversionPatternRewriterImpl &rewriterImpl);
2539 
2540  /// Legalize any unresolved type materializations.
      /// `inverseMapping` is an in/out parameter: the implementation assigns it
      /// so that later finalization steps can reuse it.
2541  LogicalResult legalizeUnresolvedMaterializations(
2542  ConversionPatternRewriter &rewriter,
2543  ConversionPatternRewriterImpl &rewriterImpl,
2544  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
2545 
2546  /// Legalize an operation result that was marked as "erased".
      // NOTE(review): listing line 2547 (presumably the return type) is missing.
2548  legalizeErasedResult(Operation *op, OpResult result,
2549  ConversionPatternRewriterImpl &rewriterImpl);
2550 
2551  /// Legalize an operation result that was replaced with a value of a different
2552  /// type.
      /// `replConverter` may be null; the implementation reports an error in that
      /// case if the result still has a live user.
2553  LogicalResult legalizeChangedResultType(
2554  Operation *op, OpResult result, Value newValue,
2555  const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2556  ConversionPatternRewriterImpl &rewriterImpl,
2557  const DenseMap<Value, SmallVector<Value>> &inverseMapping);
2558 
2559  /// Dialect conversion configuration.
      /// Held by value; must stay declared before `opLegalizer`, which is
      /// constructed from a reference to it.
2560  ConversionConfig config;
2561 
2562  /// The legalizer to use when converting operations.
2563  OperationLegalizer opLegalizer;
2564 
2565  /// The conversion mode to use when legalizing operations.
2566  OpConversionMode mode;
2567 };
2568 } // namespace mlir
2569 
/// Legalizes a single operation and translates a legalization failure into the
/// outcome required by the current conversion mode:
///  - Full: any failure is reported as an error on `op`.
///  - Partial: a failure is an error only if the legalizer says `op` is
///    illegal; otherwise `op` is recorded in `config.unlegalizedOps` (if set).
///  - Analysis: a failure is ignored; a success records `op` in
///    `config.legalizableOps` (if set).
/// Every path that does not emit an error returns success.
2570 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2571  Operation *op) {
2572  // Legalize the given operation.
2573  if (failed(opLegalizer.legalize(op, rewriter))) {
2574  // Handle the case of a failed conversion for each of the different modes.
2575  // Full conversions expect all operations to be converted.
2576  if (mode == OpConversionMode::Full)
2577  return op->emitError()
2578  << "failed to legalize operation '" << op->getName() << "'";
2579  // Partial conversions allow conversions to fail iff the operation was not
2580  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2581  // set, non-legalizable ops are added to that set.
2582  if (mode == OpConversionMode::Partial) {
2583  if (opLegalizer.isIllegal(op))
2584  return op->emitError()
2585  << "failed to legalize operation '" << op->getName()
2586  << "' that was explicitly marked illegal";
2587  if (config.unlegalizedOps)
2588  config.unlegalizedOps->insert(op);
2589  }
      // A failure in Analysis mode falls through to the final `return success()`
      // without recording anything.
2590  } else if (mode == OpConversionMode::Analysis) {
2591  // Analysis conversions don't fail if any operations fail to legalize,
2592  // they are only interested in the operations that were successfully
2593  // legalized.
2594  if (config.legalizableOps)
2595  config.legalizableOps->insert(op);
2596  }
2597  return success();
2598 }
2599 
// NOTE(review): the signature line of this definition (listing line 2600) is
// missing from this extract. The body reads a range named `ops` and the members
// `opLegalizer`, `config` and `mode`, so it is presumably the multi-operation
// `convert` entry point of OperationConverter; confirm against the original.
//
// Flow: gather the operations to convert, legalize each one, finalize, and then
// either apply or roll back the recorded rewrites. Any failure rolls back all
// rewrites before returning failure.
2601  if (ops.empty())
2602  return success();
2603  const ConversionTarget &target = opLegalizer.getTarget();
2604 
2605  // Compute the set of operations and blocks to convert.
2606  SmallVector<Operation *> toConvert;
2607  for (auto *op : ops) {
      // NOTE(review): listing line 2608 is missing; the lambda below returns
      // WalkResult values, so it is presumably the callback of a walk rooted at
      // `op`. Traversal order cannot be confirmed from here.
2609  [&](Operation *op) {
2610  toConvert.push_back(op);
2611  // Don't check this operation's children for conversion if the
2612  // operation is recursively legal.
2613  auto legalityInfo = target.isLegal(op);
2614  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2615  return WalkResult::skip();
2616  return WalkResult::advance();
2617  });
2618  }
2619 
2620  // Convert each operation and discard rewrites on failure.
      // The context is taken from the first operation; `ops` is known to be
      // non-empty because of the early return above.
2621  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2622  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2623 
2624  for (auto *op : toConvert)
2625  if (failed(convert(rewriter, op)))
      // Comma operator: undo every recorded rewrite, then yield failure.
2626  return rewriterImpl.undoRewrites(), failure();
2627 
2628  // Now that all of the operations have been converted, finalize the conversion
2629  // process to ensure any lingering conversion artifacts are cleaned up and
2630  // legalized.
2631  if (failed(finalize(rewriter)))
2632  return rewriterImpl.undoRewrites(), failure();
2633 
2634  // After a successful conversion, apply rewrites if this is not an analysis
2635  // conversion.
      // Analysis mode rolls everything back but still reports success.
2636  if (mode == OpConversionMode::Analysis) {
2637  rewriterImpl.undoRewrites();
2638  } else {
2639  rewriterImpl.applyRewrites();
2640  }
2641  return success();
2642 }
2643 
// Post-conversion clean-up, run after every operation has been converted.
// It resolves leftover materializations first, then converted block argument
// types, and finally checks each replaced operation result whose replacement
// is null or has a different type. Returns failure as soon as any step fails.
// NOTE(review): listing line 2644 (presumably the return type) is missing.
2645 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2646  std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
2647  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
      // `||` short-circuits: argument types are only legalized when the
      // materialization step succeeded.
2648  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2649  inverseMapping)) ||
2650  failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2651  return failure();
2652 
2653  // Process requested operation replacements.
      // Indexed loop with the size re-read on every iteration; presumably the
      // rewrite list can grow while results are being legalized (confirm).
2654  for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
2655  auto *opReplacement =
2656  dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
      // Only operation replacements flagged as having changed results matter.
2657  if (!opReplacement || !opReplacement->hasChangedResults())
2658  continue;
2659  Operation *op = opReplacement->getOperation();
2660  for (OpResult result : op->getResults()) {
2661  Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2662 
2663  // If the operation result was replaced with null, all of the uses of this
2664  // value should be replaced.
2665  if (!newValue) {
2666  if (failed(legalizeErasedResult(op, result, rewriterImpl)))
2667  return failure();
2668  continue;
2669  }
2670 
2671  // Otherwise, check to see if the type of the result changed.
2672  if (result.getType() == newValue.getType())
2673  continue;
2674 
2675  // Compute the inverse mapping only if it is really needed.
      // legalizeUnresolvedMaterializations assigns `inverseMapping` on entry, so
      // this only fires if the optional was left empty.
2676  if (!inverseMapping)
2677  inverseMapping = rewriterImpl.mapping.getInverse();
2678 
2679  // Legalize this result.
      // Anything created for this result is inserted right before `op`.
2680  rewriter.setInsertionPoint(op);
2681  if (failed(legalizeChangedResultType(
2682  op, result, newValue, opReplacement->getConverter(), rewriter,
2683  rewriterImpl, *inverseMapping)))
2684  return failure();
2685  }
2686  }
2687  return success();
2688 }
2689 
/// Visits every recorded block type conversion rewrite and asks it to
/// materialize conversions for values that still have live users. Returns
/// failure as soon as one rewrite fails to do so.
/// `rewriter` is not used by this body.
2690 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2691  ConversionPatternRewriter &rewriter,
2692  ConversionPatternRewriterImpl &rewriterImpl) {
2693  // Functor used to check if all users of a value will be dead after
2694  // conversion.
      // Returns the first user that is not ignored by the rewriter, or nullptr
      // when every user is ignored.
2695  auto findLiveUser = [&](Value val) {
2696  auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2697  return rewriterImpl.isOpIgnored(user);
2698  });
2699  return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2700  };
2701  // Note: `rewrites` may be reallocated as the loop is running.
      // Hence the index-based loop, with the size re-read each iteration and the
      // element reference re-fetched inside the body.
2702  for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
2703  ++i) {
2704  auto &rewrite = rewriterImpl.rewrites[i];
2705  if (auto *blockTypeConversionRewrite =
2706  dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
2707  if (failed(blockTypeConversionRewrite->materializeLiveConversions(
2708  findLiveUser)))
2709  return failure();
2710  }
2711  return success();
2712 }
2713 
2714 /// Replace the results of a materialization operation with the given values.
      /// Besides rewriting the uses, this re-points every value that was mapped
      /// to a materialization result so that it maps to the replacement instead.
      /// `matResults` and `values` are paired element-wise.
      // NOTE(review): listing line 2716 (the function name and first parameter)
      // is missing; call sites in this file pass `rewriterImpl` first.
2715 static void
2717  ResultRange matResults, ValueRange values,
2718  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2719  matResults.replaceAllUsesWith(values);
2720 
2721  // For each of the materialization results, update the inverse mappings to
2722  // point to the replacement values.
2723  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
2724  auto inverseMapIt = inverseMapping.find(matResult);
      // Nothing was mapped to this result, so there is nothing to re-point.
2725  if (inverseMapIt == inverseMapping.end())
2726  continue;
2727 
2728  // Update the reverse mapping, or remove the mapping if we couldn't update
2729  // it. Not being able to update signals that the mapping would have become
2730  // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
2731  // propagated through temporary materializations. We simply drop the
2732  // mapping, and let the post-conversion replacement logic handle updating
2733  // uses.
2734  for (Value inverseMapVal : inverseMapIt->second)
2735  if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
2736  rewriterImpl.mapping.erase(inverseMapVal);
2737  }
2738 }
2739 
2740 /// Compute all of the unresolved materializations that will persist beyond the
2741 /// conversion process, and require inserting a proper user materialization for.
      /// Outputs: `materializationOps` is filled with a cast-op -> rewrite entry
      /// for every unresolved materialization, and `necessaryMaterializations`
      /// receives those that still have a live user once the forwarding and
      /// remapping shortcuts below have been tried.
      // NOTE(review): listing lines 2742-2743 (the function name and the start of
      // the first parameter) are missing; the name is taken from the call site.
2744  &materializationOps,
2745  ConversionPatternRewriter &rewriter,
2746  ConversionPatternRewriterImpl &rewriterImpl,
2747  DenseMap<Value, SmallVector<Value>> &inverseMapping,
2748  SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
      // A value is live if it, or any value it replaced, has a user that
      // `findFn` does not classify as dead.
2749  auto isLive = [&](Value value) {
      // Returns true for users considered dead: a materialization that is not
      // (yet) marked necessary, or any other op ignored by the rewriter.
2750  auto findFn = [&](Operation *user) {
2751  auto matIt = materializationOps.find(user);
2752  if (matIt != materializationOps.end())
2753  return !necessaryMaterializations.count(matIt->second);
2754  return rewriterImpl.isOpIgnored(user);
2755  };
2756  // This value may be replacing another value that has a live user.
2757  for (Value inv : inverseMapping.lookup(value))
2758  if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2759  return true;
2760  // Or have live users itself.
2761  return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
2762  };
2763 
      // Finds a value of type `type` that `value` was remapped to, looking
      // through single-operand conversion casts. Never returns `invalidRoot`.
      // Declared as a named unique_function so the lambda can call itself.
2764  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
2765  [&](Value invalidRoot, Value value, Type type) {
2766  // Check to see if the input operation was remapped to a variant of the
2767  // output.
2768  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2769  if (remappedValue.getType() == type && remappedValue != invalidRoot)
2770  return remappedValue;
2771 
2772  // Check to see if the input is a materialization operation that
2773  // provides an inverse conversion. We just check blindly for
2774  // UnrealizedConversionCastOp here, but it has no effect on correctness.
2775  auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
2776  if (inputCastOp && inputCastOp->getNumOperands() == 1)
2777  return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2778  type);
2779 
2780  return Value();
2781  };
2782 
      // NOTE(review): listing line 2783 (the declaration of `worklist`) is
      // missing from this extract.
      // Seed the worklist and the op -> rewrite map with every unresolved
      // materialization recorded by the rewriter.
2784  for (auto &rewrite : rewriterImpl.rewrites) {
2785  auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
2786  if (!mat)
2787  continue;
2788  materializationOps.try_emplace(mat->getOperation(), mat);
2789  worklist.insert(mat);
2790  }
2791  while (!worklist.empty()) {
2792  UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
2793  UnrealizedConversionCastOp op = mat->getOperation();
2794 
2795  // We currently only handle target materializations here.
2796  assert(op->getNumResults() == 1 && "unexpected materialization type");
2797  OpResult opResult = op->getOpResult(0);
2798  Type outputType = opResult.getType();
2799  Operation::operand_range inputOperands = op.getOperands();
2800 
2801  // Try to forward propagate operands for user conversion casts that result
2802  // in the input types of the current cast.
      // Early-inc iteration because replacing uses of `opResult` changes the
      // user list being traversed.
2803  for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
2804  auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2805  if (!castOp)
2806  continue;
2807  if (castOp->getResultTypes() == inputOperands.getTypes()) {
2808  replaceMaterialization(rewriterImpl, opResult, inputOperands,
2809  inverseMapping);
2810  necessaryMaterializations.remove(materializationOps.lookup(user));
2811  }
2812  }
2813 
2814  // Try to avoid materializing a resolved materialization if possible.
2815  // Handle the case of a 1-1 materialization.
2816  if (inputOperands.size() == 1) {
2817  // Check to see if the input operation was remapped to a variant of the
2818  // output.
2819  Value remappedValue =
2820  lookupRemappedValue(opResult, inputOperands[0], outputType);
2821  if (remappedValue && remappedValue != opResult) {
2822  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2823  inverseMapping);
2824  necessaryMaterializations.remove(mat);
2825  continue;
2826  }
2827  } else {
2828  // TODO: Avoid materializing other types of conversions here.
2829  }
2830 
2831  // Check to see if this is an argument materialization.
      // True when any operand, or any value the result replaced, is a block
      // argument. Note that operator[] inserts an empty entry if absent.
2832  if (llvm::any_of(op->getOperands(), llvm::IsaPred<BlockArgument>) ||
2833  llvm::any_of(inverseMapping[op->getResult(0)],
2834  llvm::IsaPred<BlockArgument>)) {
2835  mat->setMaterializationKind(MaterializationKind::Argument);
2836  }
2837 
2838  // If the materialization does not have any live users, we don't need to
2839  // generate a user materialization for it.
2840  // FIXME: For argument materializations, we currently need to check if any
2841  // of the inverse mapped values are used because some patterns expect blind
2842  // value replacement even if the types differ in some cases. When those
2843  // patterns are fixed, we can drop the argument special case here.
2844  bool isMaterializationLive = isLive(opResult);
2845  if (mat->getMaterializationKind() == MaterializationKind::Argument)
2846  isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
2847  if (!isMaterializationLive)
2848  continue;
      // Already recorded as necessary: its inputs need no re-queueing.
2849  if (!necessaryMaterializations.insert(mat))
2850  continue;
2851 
2852  // Reprocess input materializations to see if they have an updated status.
      // Marking `mat` necessary can make the materializations feeding it live.
2853  for (Value input : inputOperands) {
2854  if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
      // This `mat` shadows the loop-level `mat` declared above.
2855  if (auto *mat = materializationOps.lookup(parentOp))
2856  worklist.insert(mat);
2857  }
2858  }
2859  }
2860 }
2861 
2862 /// Legalize the given unresolved materialization. Returns success if the
2863 /// materialization was legalized, failure otherwise.
      /// Materializations feeding this one are resolved first. On failure an
      /// error is emitted, with a note pointing at a live user when one exists.
      // NOTE(review): listing lines 2864 and 2866 (the function name and the
      // start of the `materializationOps` parameter) are missing.
2865  UnresolvedMaterializationRewrite &mat,
2867  &materializationOps,
2868  ConversionPatternRewriter &rewriter,
2869  ConversionPatternRewriterImpl &rewriterImpl,
2870  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
      // Returns the first user that the rewriter does not ignore, or nullptr.
2871  auto findLiveUser = [&](auto &&users) {
2872  auto liveUserIt = llvm::find_if_not(
2873  users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
2874  return liveUserIt == users.end() ? nullptr : *liveUserIt;
2875  };
2876 
      // Returns the value `value` was remapped to if it has exactly `type`,
      // otherwise a null Value.
2877  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
2878  [&](Value value, Type type) {
2879  // Check to see if the input operation was remapped to a variant of the
2880  // output.
2881  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2882  if (remappedValue.getType() == type)
2883  return remappedValue;
2884  return Value();
2885  };
2886 
2887  UnrealizedConversionCastOp op = mat.getOperation();
      // Adds the cast to the ignored ops. A failed insert is treated as "already
      // handled" (presumably insert() reports whether the element was new).
2888  if (!rewriterImpl.ignoredOps.insert(op))
2889  return success();
2890 
2891  // We currently only handle target materializations here.
2892  OpResult opResult = op->getOpResult(0);
2893  Operation::operand_range inputOperands = op.getOperands();
2894  Type outputType = opResult.getType();
2895 
2896  // If any input to this materialization is another materialization, resolve
2897  // the input first.
2898  for (Value value : op->getOperands()) {
2899  auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
2900  if (!valueCast)
2901  continue;
2902 
2903  auto matIt = materializationOps.find(valueCast);
2904  if (matIt != materializationOps.end())
      // NOTE(review): listing line 2905 is missing; given the arguments below it
      // is presumably the recursive call wrapped in `if (failed(...`.
2906  *matIt->second, materializationOps, rewriter, rewriterImpl,
2907  inverseMapping)))
2908  return failure();
2909  }
2910 
2911  // Perform a last ditch attempt to avoid materializing a resolved
2912  // materialization if possible.
2913  // Handle the case of a 1-1 materialization.
2914  if (inputOperands.size() == 1) {
2915  // Check to see if the input operation was remapped to a variant of the
2916  // output.
2917  Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2918  if (remappedValue && remappedValue != opResult) {
2919  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2920  inverseMapping);
2921  return success();
2922  }
2923  } else {
2924  // TODO: Avoid materializing other types of conversions here.
2925  }
2926 
2927  // Try to materialize the conversion.
      // Without a type converter no materialization is attempted and control
      // falls through to the error below.
2928  if (const TypeConverter *converter = mat.getConverter()) {
2929  // FIXME: Determine a suitable insertion location when there are multiple
2930  // inputs.
2931  if (inputOperands.size() == 1)
2932  rewriter.setInsertionPointAfterValue(inputOperands.front());
2933  else
2934  rewriter.setInsertionPoint(op);
2935 
2936  Value newMaterialization;
2937  switch (mat.getMaterializationKind()) {
      // NOTE(review): listing line 2938 is missing; given the comments and the
      // fallthrough below it is presumably the Argument case label.
2939  // Try to materialize an argument conversion.
2940  // FIXME: The current argument materialization hook expects the original
2941  // output type, even though it doesn't use that as the actual output type
2942  // of the generated IR. The output type is just used as an indicator of
2943  // the type of materialization to do. This behavior is really awkward in
2944  // that it diverges from the behavior of the other hooks, and can be
2945  // easily misunderstood. We should clean up the argument hooks to better
2946  // represent the desired invariants we actually care about.
2947  newMaterialization = converter->materializeArgumentConversion(
2948  rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
2949  if (newMaterialization)
2950  break;
2951 
2952  // If an argument materialization failed, fallback to trying a target
2953  // materialization.
2954  [[fallthrough]];
2955  case MaterializationKind::Target:
2956  newMaterialization = converter->materializeTargetConversion(
2957  rewriter, op->getLoc(), outputType, inputOperands);
2958  break;
2959  }
2960  if (newMaterialization) {
2961  replaceMaterialization(rewriterImpl, opResult, newMaterialization,
2962  inverseMapping);
2963  return success();
2964  }
2965  }
2966 
      // NOTE(review): listing line 2967 is missing; `diag` is used below, so it
      // presumably declares the diagnostic that this message is streamed into.
2968  << "failed to legalize unresolved materialization "
2969  "from "
2970  << inputOperands.getTypes() << " to " << outputType
2971  << " that remained live after conversion";
2972  if (Operation *liveUser = findLiveUser(op->getUsers())) {
2973  diag.attachNote(liveUser->getLoc())
2974  << "see existing live user here: " << *liveUser;
2975  }
2976  return failure();
2977 }
2978 
/// Works out which unresolved materializations must survive the conversion and
/// legalizes each of them. `inverseMapping` is an output: it is overwritten
/// with the inverse of the rewriter's current value mapping before any other
/// work is done. Returns failure on the first materialization that cannot be
/// legalized.
2979 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2980  ConversionPatternRewriter &rewriter,
2981  ConversionPatternRewriterImpl &rewriterImpl,
2982  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
2983  inverseMapping = rewriterImpl.mapping.getInverse();
2984 
2985  // As an initial step, compute all of the inserted materializations that we
2986  // expect to persist beyond the conversion process.
      // NOTE(review): listing line 2987 (the declaration of `materializationOps`)
      // is missing from this extract.
2988  SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
2989  computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
2990  *inverseMapping, necessaryMaterializations);
2991 
2992  // Once computed, legalize any necessary materializations.
2993  for (auto *mat : necessaryMaterializations) {
      // NOTE(review): listing line 2994 is missing; given the arguments below it
      // is presumably the per-materialization legalization call in `if (failed(`.
2995  *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
2996  return failure();
2997  }
2998  return success();
2999 }
3000 
/// Checks that a result replaced with null has no remaining live user.
/// Succeeds when every user of `result` is ignored by the rewriter. Otherwise
/// emits an error on `op` with a note on the first live user, and fails.
3001 LogicalResult OperationConverter::legalizeErasedResult(
3002  Operation *op, OpResult result,
3003  ConversionPatternRewriterImpl &rewriterImpl) {
3004  // If the operation result was replaced with null, all of the uses of this
3005  // value should be replaced.
3006  auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
3007  return rewriterImpl.isOpIgnored(user);
3008  });
3009  if (liveUserIt != result.user_end()) {
3010  InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
3011  << op->getName() << "' marked as erased";
3012  diag.attachNote(liveUserIt->getLoc())
3013  << "found live user of result #" << result.getResultNumber() << ": "
3014  << *liveUserIt;
3015  return failure();
3016  }
3017  return success();
3018 }
3019 
3020 /// Finds a user of the given value, or of any other value that the given value
3021 /// replaced, that was not replaced in the conversion process.
      /// Returns the first such user found, or nullptr if there is none. The
      /// search follows `inverseMapping` transitively via a worklist.
      // NOTE(review): listing line 3022 (the return type and function name) is
      // missing; the name is taken from the call site in this file.
3023  Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
3024  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
3025  SmallVector<Value> worklist(1, initialValue);
3026  while (!worklist.empty()) {
3027  Value value = worklist.pop_back_val();
3028 
3029  // Walk the users of this value to see if there are any live users that
3030  // weren't replaced during conversion.
3031  auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
3032  return rewriterImpl.isOpIgnored(user);
3033  });
3034  if (liveUserIt != value.user_end())
3035  return *liveUserIt;
      // No live user here: continue with the values that `value` replaced.
3036  auto mapIt = inverseMapping.find(value);
3037  if (mapIt != inverseMapping.end())
3038  worklist.append(mapIt->second);
3039  }
3040  return nullptr;
3041 }
3042 
/// Handles a result whose replacement value has a different type.
/// If neither the result nor any value it replaced has a live user, nothing is
/// needed. Otherwise a source materialization back to the original result type
/// is requested from `replConverter` and `result` is remapped to it. Fails,
/// with a diagnostic, when there is no converter or it produces no value.
3043 LogicalResult OperationConverter::legalizeChangedResultType(
3044  Operation *op, OpResult result, Value newValue,
3045  const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
3046  ConversionPatternRewriterImpl &rewriterImpl,
3047  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
3048  Operation *liveUser =
3049  findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
3050  if (!liveUser)
3051  return success();
3052 
3053  // Functor used to emit a conversion error for a failed materialization.
3054  auto emitConversionError = [&] {
      // NOTE(review): listing line 3055 is missing; `diag` is used below, so it
      // presumably declares the diagnostic that this message is streamed into.
3056  << "failed to materialize conversion for result #"
3057  << result.getResultNumber() << " of operation '"
3058  << op->getName()
3059  << "' that remained live after conversion";
3060  diag.attachNote(liveUser->getLoc())
3061  << "see existing live user here: " << *liveUser;
3062  return failure();
3063  };
3064 
3065  // If the replacement has a type converter, attempt to materialize a
3066  // conversion back to the original type.
3067  if (!replConverter)
3068  return emitConversionError();
3069 
3070  // Materialize a conversion for this live result value.
      // The insertion point is whatever the caller configured on `rewriter`.
3071  Type resultType = result.getType();
3072  Value convertedValue = replConverter->materializeSourceConversion(
3073  rewriter, op->getLoc(), resultType, newValue);
3074  if (!convertedValue)
3075  return emitConversionError();
3076 
3077  rewriterImpl.mapping.map(result, convertedValue);
3078  return success();
3079 }
3080 
3081 //===----------------------------------------------------------------------===//
3082 // Type Conversion
3083 //===----------------------------------------------------------------------===//
3084 
// NOTE(review): listing line 3085 (the start of this signature) is missing. The
// body uses `origInputNo` and `types`, and callers in this file invoke it as
// `addInputs(inputNo, convertedTypes)`.
// Maps original input `origInputNo` to the `types.size()` new inputs that are
// about to be appended at the current end of `argTypes`. `types` must be
// non-empty.
3086  ArrayRef<Type> types) {
3087  assert(!types.empty() && "expected valid types");
3088  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3089  addInputs(types);
3090 }
3091 
// NOTE(review): listing line 3092 (the signature) is missing; the body reads a
// parameter named `types`.
// Appends `types` to the converted argument type list without recording any
// mapping from an original input. `types` must be non-empty.
3093  assert(!types.empty() &&
3094  "1->0 type remappings don't need to be added explicitly");
3095  argTypes.append(types.begin(), types.end());
3096 }
3097 
/// Records that original input `origInputNo` maps to `newInputCount` new
/// inputs starting at index `newInputNo`. No replacement value is stored.
/// An input may be remapped only once, and `newInputCount` must be non-zero;
/// both conditions are checked with asserts only.
3098 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3099  unsigned newInputNo,
3100  unsigned newInputCount) {
3101  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3102  assert(newInputCount != 0 && "expected valid input count");
3103  remappedInputs[origInputNo] =
3104  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
3105 }
3106 
3108  Value replacementValue) {
3109  assert(!remappedInputs[origInputNo] && "input has already been remapped");
3110  remappedInputs[origInputNo] =
3111  InputMapping{origInputNo, /*size=*/0, replacementValue};
3112 }
3113 
// NOTE(review): listing line 3114 (the start of this signature) is missing. The
// body converts a type `t` and appends to `results`; callers in this file
// invoke it as `convertType(t, results)`.
// Converts `t` into zero or more types appended to `results`, consulting two
// caches first. On a cache miss the registered conversion callbacks are tried
// and the first definitive answer is cached.
3115  SmallVectorImpl<Type> &results) const {
3116  {
      // Both locks in this function are constructed with std::defer_lock, i.e.
      // not yet held.
3117  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3118  std::defer_lock);
      // NOTE(review): listing line 3119 is missing; it may make the lock() call
      // below conditional. Confirm against the original.
3120  cacheReadLock.lock();
3121  auto existingIt = cachedDirectConversions.find(t);
3122  if (existingIt != cachedDirectConversions.end()) {
      // A null cached type records an earlier failed conversion.
3123  if (existingIt->second)
3124  results.push_back(existingIt->second);
3125  return success(existingIt->second != nullptr);
3126  }
3127  auto multiIt = cachedMultiConversions.find(t);
3128  if (multiIt != cachedMultiConversions.end()) {
3129  results.append(multiIt->second.begin(), multiIt->second.end());
3130  return success();
3131  }
3132  }
3133  // Walk the added converters in reverse order to apply the most recently
3134  // registered first.
      // Remember where this call's output starts; `results` may already hold
      // entries from the caller.
3135  size_t currentCount = results.size();
3136 
3137  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3138  std::defer_lock);
3139 
3140  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
      // An empty optional means this callback has no answer; try the next one.
3141  if (std::optional<LogicalResult> result = converter(t, results)) {
      // NOTE(review): listing line 3142 is missing; it may make the lock() call
      // below conditional. Confirm against the original.
3143  cacheWriteLock.lock();
3144  if (!succeeded(*result)) {
      // Cache the failure as a null direct conversion.
3145  cachedDirectConversions.try_emplace(t, nullptr);
3146  return failure();
3147  }
3148  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
      // Exactly one new type goes to the direct cache; any other count,
      // including zero, goes to the multi-conversion cache.
3149  if (newTypes.size() == 1)
3150  cachedDirectConversions.try_emplace(t, newTypes.front());
3151  else
3152  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3153  return success();
3154  }
3155  }
      // No callback gave a definitive answer; this outcome is not cached.
3156  return failure();
3157 }
3158 
// NOTE(review): listing line 3159 (the signature) is missing; the body converts
// a value named `t` and returns either a type or nullptr.
// Single-result conversion: returns the converted type only when conversion
// succeeds and yields exactly one type; returns nullptr otherwise.
3160  // Use the multi-type result version to convert the type.
3161  SmallVector<Type, 1> results;
3162  if (failed(convertType(t, results)))
3163  return nullptr;
3164 
3165  // Check to ensure that only one type was produced.
3166  return results.size() == 1 ? results.front() : nullptr;
3167 }
3168 
// NOTE(review): the start of this signature is missing from the listing; the
// body iterates a range named `types`. Callers in this file invoke it as
// `convertTypes(range, results)`.
// Converts each type in order, appending to `results`, and stops at the first
// failure. Types appended before the failure are left in `results`.
3171  SmallVectorImpl<Type> &results) const {
3172  for (Type type : types)
3173  if (failed(convertType(type, results)))
3174  return failure();
3175  return success();
3176 }
3177 
/// A type is legal when the single-result conversion maps it to itself.
3178 bool TypeConverter::isLegal(Type type) const {
3179  return convertType(type) == type;
3180 }
// NOTE(review): listing line 3181 (the signature) is missing; the body reads a
// pointer named `op`.
// An operation is legal when all of its operand types and all of its result
// types are legal.
3182  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
3183 }
3184 
/// A region is legal when the argument types of every block in it are legal.
/// A region with no blocks is therefore legal.
3185 bool TypeConverter::isLegal(Region *region) const {
3186  return llvm::all_of(*region, [this](Block &block) {
3187  return isLegal(block.getArgumentTypes());
3188  });
3189 }
3190 
/// A function type is legal when all of its input and result types are legal;
/// the two lists are checked as one concatenated range.
3191 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3192  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
3193 }
3194 
// NOTE(review): the start of this signature is missing from the listing; the
// body uses `inputNo`, `type` and `result`. Callers in this file invoke it as
// `convertSignatureArg(index, type, result)`.
// Converts one signature argument type and records the outcome in `result`.
// Fails if the type cannot be converted. A conversion to zero types succeeds
// and records nothing.
3197  SignatureConversion &result) const {
3198  // Try to convert the given input type.
3199  SmallVector<Type, 1> convertedTypes;
3200  if (failed(convertType(type, convertedTypes)))
3201  return failure();
3202 
3203  // If this argument is being dropped, there is nothing left to do.
3204  if (convertedTypes.empty())
3205  return success();
3206 
3207  // Otherwise, add the new inputs.
3208  result.addInputs(inputNo, convertedTypes);
3209  return success();
3210 }
// NOTE(review): the start of this signature is missing from the listing; the
// body indexes a range named `types`.
// Converts each type in `types`, treating element `i` as original input number
// `origInputOffset + i`. Stops at the first failure; entries already added to
// `result` are left in place.
3213  SignatureConversion &result,
3214  unsigned origInputOffset) const {
3215  for (unsigned i = 0, e = types.size(); i != e; ++i)
3216  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3217  return failure();
3218  return success();
3219 }
3220 
/// Tries the given materialization callbacks from last to first and returns
/// the value of the first one that yields an engaged optional. That value is
/// returned as-is, so it may itself be null. Returns nullptr when every
/// callback yields an empty optional.
3221 Value TypeConverter::materializeConversion(
3222  ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
3223  Location loc, Type resultType, ValueRange inputs) const {
3224  for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
3225  if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
3226  return *result;
3227  return nullptr;
3228 }
3229 
3230 std::optional<TypeConverter::SignatureConversion>
3232  SignatureConversion conversion(block->getNumArguments());
3233  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3234  return std::nullopt;
3235  return conversion;
3236 }
3237 
3238 //===----------------------------------------------------------------------===//
3239 // Type attribute conversion
3240 //===----------------------------------------------------------------------===//
3243  return AttributeConversionResult(attr, resultTag);
3244 }
3245 
3248  return AttributeConversionResult(nullptr, naTag);
3249 }
3250 
3253  return AttributeConversionResult(nullptr, abortTag);
3254 }
3255 
3257  return impl.getInt() == resultTag;
3258 }
3259 
3261  return impl.getInt() == naTag;
3262 }
3263 
3265  return impl.getInt() == abortTag;
3266 }
3267 
3269  assert(hasResult() && "Cannot get result from N/A or abort");
3270  return impl.getPointer();
3271 }
3272 
3273 std::optional<Attribute>
3275  for (const TypeAttributeConversionCallbackFn &fn :
3276  llvm::reverse(typeAttributeConversions)) {
3277  AttributeConversionResult res = fn(type, attr);
3278  if (res.hasResult())
3279  return res.getResult();
3280  if (res.isAbort())
3281  return std::nullopt;
3282  }
3283  return std::nullopt;
3284 }
3285 
3286 //===----------------------------------------------------------------------===//
3287 // FunctionOpInterfaceSignatureConversion
3288 //===----------------------------------------------------------------------===//
3289 
/// Rewrites the signature of `funcOp` using `typeConverter`.
/// Fails if the op's function type is not a FunctionType, or if converting the
/// inputs, the results, or the body region's types fails. On success the op's
/// type is replaced in place through the rewriter.
3290 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3291  const TypeConverter &typeConverter,
3292  ConversionPatternRewriter &rewriter) {
3293  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3294  if (!type)
3295  return failure();
3296 
3297  // Convert the original function types.
      // The steps run in order (inputs, results, region) and `||` stops at the
      // first failure, so the region is only converted when both type lists
      // converted.
3298  TypeConverter::SignatureConversion result(type.getNumInputs());
3299  SmallVector<Type, 1> newResults;
3300  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3301  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3302  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3303  typeConverter, &result)))
3304  return failure();
3305 
3306  // Update the function signature in-place.
3307  auto newType = FunctionType::get(rewriter.getContext(),
3308  result.getConvertedTypes(), newResults);
3309 
3310  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3311 
3312  return success();
3313 }
3314 
3315 /// Create a default conversion pattern that rewrites the type signature of a
3316 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3317 /// represent their type.
3318 namespace {
      /// Pattern registered for one operation name. It casts the matched op to
      /// FunctionOpInterface and delegates to convertFuncOpTypes.
3319 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3320  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3321  MLIRContext *ctx,
3322  const TypeConverter &converter)
3323  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3324 
      // NOTE(review): listing line 3325 (presumably the return type) is missing.
      // The converted operands are not used.
3326  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3327  ConversionPatternRewriter &rewriter) const override {
      // `cast` (unlike `dyn_cast`) assumes `op` implements the interface.
3328  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3329  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3330  }
3331 };
3332 
      /// Interface-based variant: receives the op already typed as
      /// FunctionOpInterface and delegates to convertFuncOpTypes.
3333 struct AnyFunctionOpInterfaceSignatureConversion
3334  : public OpInterfaceConversionPattern<FunctionOpInterface> {
      // NOTE(review): listing lines 3335 and 3337 are missing from this extract.
3336 
3338  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3339  ConversionPatternRewriter &rewriter) const override {
3340  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3341  }
3342 };
3343 } // namespace
3344 
3347  const TypeConverter &converter,
3348  ConversionPatternRewriter &rewriter) {
3349  assert(op && "Invalid op");
3350  Location loc = op->getLoc();
3351  if (converter.isLegal(op))
3352  return rewriter.notifyMatchFailure(loc, "op already legal");
3353 
3354  OperationState newOp(loc, op->getName());
3355  newOp.addOperands(operands);
3356 
3357  SmallVector<Type> newResultTypes;
3358  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3359  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3360 
3361  newOp.addTypes(newResultTypes);
3362  newOp.addAttributes(op->getAttrs());
3363  return rewriter.create(newOp);
3364 }
3365 
3367  StringRef functionLikeOpName, RewritePatternSet &patterns,
3368  const TypeConverter &converter) {
3369  patterns.add<FunctionOpInterfaceSignatureConversion>(
3370  functionLikeOpName, patterns.getContext(), converter);
3371 }
3372 
3374  RewritePatternSet &patterns, const TypeConverter &converter) {
3375  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3376  converter, patterns.getContext());
3377 }
3378 
3379 //===----------------------------------------------------------------------===//
3380 // ConversionTarget
3381 //===----------------------------------------------------------------------===//
3382 
// NOTE(review): the first signature line (listing line 3383) is missing from
// this extract; presumably
// `void ConversionTarget::setOpAction(OperationName op,` - confirm.
/// Records `action` as the legalization action for operation `op`.
3384  LegalizationAction action) {
// Only the `action` field of the entry is written here.
3385  legalOperations[op].action = action;
3386 }
3387 
// NOTE(review): the first signature line (listing line 3388) is missing from
// this extract; presumably
// `void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,`
// - confirm.
/// Records `action` as the legalization action for every dialect name in
/// `dialectNames`.
3389  LegalizationAction action) {
3390  for (StringRef dialect : dialectNames)
3391  legalDialects[dialect] = action;
3392 }
3393 
// NOTE(review): the first signature line (listing line 3394) is missing from
// this extract; presumably
// `auto ConversionTarget::getOpAction(OperationName op) const` - confirm.
/// Returns the action from getOpInfo(op), or an empty optional when getOpInfo
/// has no information for `op`.
3395  -> std::optional<LegalizationAction> {
3396  std::optional<LegalizationInfo> info = getOpInfo(op);
3397  return info ? info->action : std::optional<LegalizationAction>();
3398 }
3399 
// NOTE(review): the first signature line (listing line 3400) is missing from
// this extract; presumably
// `auto ConversionTarget::isLegal(Operation *op) const` - confirm.
/// Returns legality details for `op` when it is legal on this target, and
/// std::nullopt when there is no legalization info for its name or when the
/// instance is not legal.
3401  -> std::optional<LegalOpDetails> {
3402  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3403  if (!info)
3404  return std::nullopt;
3405 
3406  // Returns true if this operation instance is known to be legal.
3407  auto isOpLegal = [&] {
3408  // Handle dynamic legality with the provided legality function; a callback that returns no value falls through to the check below.
3409  if (info->action == LegalizationAction::Dynamic) {
3410  std::optional<bool> result = info->legalityFn(op);
3411  if (result)
3412  return *result;
3413  }
3414 
3415  // Otherwise, the operation is only legal if it was marked 'Legal'.
3416  return info->action == LegalizationAction::Legal;
3417  };
3418  if (!isOpLegal())
3419  return std::nullopt;
3420 
3421  // This operation is legal, compute any additional legality information.
3422  LegalOpDetails legalityDetails;
3423  if (info->isRecursivelyLegal) {
// A registered per-operation callback can veto recursive legality; a callback
// that returns no value counts as "recursively legal" (value_or(true)).
3424  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3425  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3426  legalityDetails.isRecursivelyLegal =
3427  legalityFnIt->second(op).value_or(true);
3428  } else {
3429  legalityDetails.isRecursivelyLegal = true;
3430  }
3431  }
3432  return legalityDetails;
3433 }
3434 
// NOTE(review): the signature line (listing line 3435) is missing from this
// extract; presumably `bool ConversionTarget::isIllegal(Operation *op) const {`
// - confirm.
/// Returns true only when `op` is known to be illegal: its info is marked
/// Illegal, or it is Dynamic and the legality callback explicitly returns
/// false. Missing info or a callback without an answer yields false.
3436  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3437  if (!info)
3438  return false;
3439 
3440  if (info->action == LegalizationAction::Dynamic) {
3441  std::optional<bool> result = info->legalityFn(op);
// No answer from the callback means "not known to be illegal".
3442  if (!result)
3443  return false;
3444 
3445  return !(*result);
3446  }
3447 
3448  return info->action == LegalizationAction::Illegal;
3449 }
3450 
// NOTE(review): the signature (listing lines 3451-3453) is missing from this
// extract; the reference list shows
// `static ConversionTarget::DynamicLegalityCallbackFn
// composeLegalityCallbacks(DynamicLegalityCallbackFn oldCallback,
// DynamicLegalityCallbackFn newCallback)` - confirm.
/// Combines two legality callbacks. With no old callback the new one is
/// returned as is; otherwise the result asks the new callback first and falls
/// back to the old one when the new callback returns no value.
3454  if (!oldCallback)
3455  return newCallback;
3456 
// Both callbacks are moved into the closure, so it owns them.
3457  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3458  Operation *op) -> std::optional<bool> {
3459  if (std::optional<bool> result = newCl(op))
3460  return *result;
3461 
3462  return oldCl(op);
3463  };
3464  return chain;
3465 }
3466 
3467 void ConversionTarget::setLegalityCallback(
3468  OperationName name, const DynamicLegalityCallbackFn &callback) {
3469  assert(callback && "expected valid legality callback");
3470  auto *infoIt = legalOperations.find(name);
3471  assert(infoIt != legalOperations.end() &&
3472  infoIt->second.action == LegalizationAction::Dynamic &&
3473  "expected operation to already be marked as dynamically legal");
3474  infoIt->second.legalityFn =
3475  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3476 }
3477 
// NOTE(review): the first signature line (listing line 3478) is missing from
// this extract; presumably `void ConversionTarget::markOpRecursivelyLegal(`
// - confirm.
/// Flags the operation `name` as recursively legal. The operation must already
/// have an entry whose action is not Illegal. A non-null `callback` is
/// composed with any callback already stored for `name`; a null `callback`
/// removes the stored one.
3479  OperationName name, const DynamicLegalityCallbackFn &callback) {
3480  auto *infoIt = legalOperations.find(name);
3481  assert(infoIt != legalOperations.end() &&
3482  infoIt->second.action != LegalizationAction::Illegal &&
3483  "expected operation to already be marked as legal");
3484  infoIt->second.isRecursivelyLegal = true;
3485  if (callback)
3486  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3487  std::move(opRecursiveLegalityFns[name]), callback);
3488  else
3489  opRecursiveLegalityFns.erase(name);
3490 }
3491 
3492 void ConversionTarget::setLegalityCallback(
3493  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3494  assert(callback && "expected valid legality callback");
3495  for (StringRef dialect : dialects)
3496  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3497  std::move(dialectLegalityFns[dialect]), callback);
3498 }
3499 
3500 void ConversionTarget::setLegalityCallback(
3501  const DynamicLegalityCallbackFn &callback) {
3502  assert(callback && "expected valid legality callback");
3503  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3504 }
3505 
3506 auto ConversionTarget::getOpInfo(OperationName op) const
3507  -> std::optional<LegalizationInfo> {
3508  // Check for info for this specific operation.
3509  const auto *it = legalOperations.find(op);
3510  if (it != legalOperations.end())
3511  return it->second;
3512  // Check for info for the parent dialect.
3513  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3514  if (dialectIt != legalDialects.end()) {
3515  DynamicLegalityCallbackFn callback;
3516  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3517  if (dialectFn != dialectLegalityFns.end())
3518  callback = dialectFn->second;
3519  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3520  callback};
3521  }
3522  // Otherwise, check if we mark unknown operations as dynamic.
3523  if (unknownLegalityFn)
3524  return LegalizationInfo{LegalizationAction::Dynamic,
3525  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3526  return std::nullopt;
3527 }
3528 
3529 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3530 //===----------------------------------------------------------------------===//
3531 // PDL Configuration
3532 //===----------------------------------------------------------------------===//
3533 
// NOTE(review): the signature line (listing line 3534) is missing from this
// extract; presumably the `notifyRewriteBegin(PatternRewriter &rewriter)` hook
// from the reference list - confirm.
/// Publishes this object's type converter on the conversion rewriter's
/// implementation as `currentTypeConverter`.
3535  auto &rewriterImpl =
3536  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3537  rewriterImpl.currentTypeConverter = getTypeConverter();
3538 }
3539 
// NOTE(review): the signature line (listing line 3540) is missing from this
// extract; presumably the `notifyRewriteEnd(PatternRewriter &rewriter)` hook
// from the reference list - confirm.
/// Clears `currentTypeConverter` on the conversion rewriter's implementation.
3541  auto &rewriterImpl =
3542  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3543  rewriterImpl.currentTypeConverter = nullptr;
3544 }
3545 
3546 /// Remap the given values through the rewriter's getRemappedValues. Yields
3547 /// failure when the remapping fails, otherwise the remapped values.
// NOTE(review): the signature (listing lines 3548-3549) is missing from this
// extract; the reference list shows
// `static FailureOr<SmallVector<Value>> pdllConvertValues(
//      ConversionPatternRewriter &rewriter, ValueRange values)` - confirm.
3550  SmallVector<Value> mappedValues;
3551  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3552  return failure();
3553  return std::move(mappedValues);
3554 }
3555 
// NOTE(review): the signature line (listing line 3556) is missing from this
// extract; presumably
// `void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {`
// - confirm.
/// Registers four PDL rewrite functions on the PDL patterns of `patterns`:
/// "convertValue", "convertValues", "convertType" and "convertTypes".
// "convertValue": remap a single value via pdllConvertValues.
3557  patterns.getPDLPatterns().registerRewriteFunction(
3558  "convertValue",
3559  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3560  auto results = pdllConvertValues(
3561  static_cast<ConversionPatternRewriter &>(rewriter), value);
3562  if (failed(results))
3563  return failure();
// Presumably exactly one remapped value comes back for one input - confirm.
3564  return results->front();
3565  });
// "convertValues": remap a range of values via pdllConvertValues.
3566  patterns.getPDLPatterns().registerRewriteFunction(
3567  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3568  return pdllConvertValues(
3569  static_cast<ConversionPatternRewriter &>(rewriter), values);
3570  });
// "convertType": convert one type with the rewriter's current type converter;
// without a current converter the type is returned unchanged.
3571  patterns.getPDLPatterns().registerRewriteFunction(
3572  "convertType",
3573  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3574  auto &rewriterImpl =
3575  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3576  if (const TypeConverter *converter =
3577  rewriterImpl.currentTypeConverter) {
3578  if (Type newType = converter->convertType(type))
3579  return newType;
3580  return failure();
3581  }
3582  return type;
3583  });
// "convertTypes": convert a list of types with the current type converter;
// without a current converter the types are copied through unchanged.
// NOTE(review): listing line 3587 (the rest of the lambda's parameter list,
// which introduces `types`) is missing from this extract.
3584  patterns.getPDLPatterns().registerRewriteFunction(
3585  "convertTypes",
3586  [](PatternRewriter &rewriter,
3588  auto &rewriterImpl =
3589  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3590  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3591  if (!converter)
3592  return SmallVector<Type>(types);
3593 
3594  SmallVector<Type> remappedTypes;
3595  if (failed(converter->convertTypes(types, remappedTypes)))
3596  return failure();
3597  return std::move(remappedTypes);
3598  });
3599 }
3600 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3601 
3602 //===----------------------------------------------------------------------===//
3603 // Op Conversion Entry Points
3604 //===----------------------------------------------------------------------===//
3605 
3606 //===----------------------------------------------------------------------===//
3607 // Partial Conversion
3608 
// NOTE(review): the first signature line (listing line 3609) is missing from
// this extract; presumably `LogicalResult mlir::applyPartialConversion(`
// - confirm.
/// Runs an OperationConverter in OpConversionMode::Partial over `ops` and
/// returns the result of convertOperations.
3610  ArrayRef<Operation *> ops, const ConversionTarget &target,
3611  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3612  OperationConverter opConverter(target, patterns, config,
3613  OpConversionMode::Partial);
3614  return opConverter.convertOperations(ops);
3615 }
// NOTE(review): the head of this signature (listing lines 3616-3617) is
// missing from this extract; presumably the single-operation overload of
// applyPartialConversion taking `op` and `target` - confirm.
/// Wraps `op` in an ArrayRef and forwards to the multi-operation
/// applyPartialConversion.
3618  const FrozenRewritePatternSet &patterns,
3619  ConversionConfig config) {
3620  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3621 }
3622 
3623 //===----------------------------------------------------------------------===//
3624 // Full Conversion
3625 
// NOTE(review): the first signature line (listing line 3626) is missing from
// this extract; presumably
// `LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,`
// - confirm.
/// Runs an OperationConverter in OpConversionMode::Full over `ops` and
/// returns the result of convertOperations.
3627  const ConversionTarget &target,
3628  const FrozenRewritePatternSet &patterns,
3629  ConversionConfig config) {
3630  OperationConverter opConverter(target, patterns, config,
3631  OpConversionMode::Full);
3632  return opConverter.convertOperations(ops);
3633 }
// NOTE(review): the first signature line (listing line 3634) is missing from
// this extract; presumably the single-operation overload
// `LogicalResult mlir::applyFullConversion(Operation *op,` - confirm.
/// Wraps `op` in an ArrayRef and forwards to the multi-operation
/// applyFullConversion.
3635  const ConversionTarget &target,
3636  const FrozenRewritePatternSet &patterns,
3637  ConversionConfig config) {
3638  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3639 }
3640 
3641 //===----------------------------------------------------------------------===//
3642 // Analysis Conversion
3643 
// NOTE(review): the head of this signature (listing lines 3644-3645) is
// missing from this extract; presumably
// `LogicalResult mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
// ConversionTarget &target,` - confirm.
/// Runs an OperationConverter in OpConversionMode::Analysis over `ops` and
/// returns the result of convertOperations.
3646  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3647  OperationConverter opConverter(target, patterns, config,
3648  OpConversionMode::Analysis);
3649  return opConverter.convertOperations(ops);
3650 }
// NOTE(review): the head of this signature (listing lines 3651-3652) is
// missing from this extract; presumably the single-operation overload of
// applyAnalysisConversion taking `op` and `target` - confirm.
/// Wraps `op` in an ArrayRef and forwards to the multi-operation
/// applyAnalysisConversion.
3653  const FrozenRewritePatternSet &patterns,
3654  ConversionConfig config) {
3655  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3656 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterializationRewrite &mat, DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
static RewriteTy * findSingleRewrite(R &&rewrites, Block *block)
Find the single rewrite object of the specified type and block among the given rewrites.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterializationRewrite * > &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process,...
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:137
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
bool empty()
Definition: Block.h:145
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:93
OpListType & getOperations()
Definition: Block.h:134
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
iterator end()
Definition: Block.h:141
iterator begin()
Definition: Block.h:140
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region.
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:329
Block::iterator getPoint() const
Definition: Builders.h:342
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:339
Block * getBlock() const
Definition: Builders.h:341
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:423
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:479
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
OpResult getOpResult(unsigned idx)
Definition: Operation.h:416
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:605
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:892
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:281
MLIRContext * getContext() const
Definition: PatternMatch.h:822
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:828
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_iterator user_end() const
Definition: Value.h:227
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition: Argument.h:64
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a single value that remaps an existing signature input...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
DenseSet< void * > erased
Pointers to all erased operations and blocks.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks within that region.
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc, ValueRange inputs, Type origOutputType, Type outputType, const TypeConverter *converter)
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites remain.
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
LogicalResult convertNonEntryRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions={})
Convert the types of non-entry block arguments within the given region.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter)
Apply a signature conversion on the given region, using converter for materializations if not null.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
Value buildUnresolvedMaterialization(MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, Type origOutputType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
FailureOr< Block * > convertBlockSignature(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion=nullptr)
Attempt to convert the signature of the given block, if successful a new block is returned containing the new arguments.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, const TypeConverter *converter)
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.