MLIR  20.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Iterators.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::detail;
29 
30 #define DEBUG_TYPE "dialect-conversion"
31 
32 /// A utility function to log a successful result for the given reason.
33 template <typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
35  LLVM_DEBUG({
36  os.unindent();
37  os.startLine() << "} -> SUCCESS";
38  if (!fmt.empty())
39  os.getOStream() << " : "
40  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41  os.getOStream() << "\n";
42  });
43 }
44 
45 /// A utility function to log a failure result for the given reason.
46 template <typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
48  LLVM_DEBUG({
49  os.unindent();
50  os.startLine() << "} -> FAILURE : "
51  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
52  << "\n";
53  });
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // ConversionValueMapping
58 //===----------------------------------------------------------------------===//
59 
60 namespace {
61 /// This class wraps a IRMapping to provide recursive lookup
62 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
63 struct ConversionValueMapping {
64  /// Lookup a mapped value within the map. If a mapping for the provided value
65  /// does not exist then return the provided value. If `desiredType` is
66  /// non-null, returns the most recently mapped value with that type. If an
67  /// operand of that type does not exist, defaults to normal behavior.
68  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
69 
70  /// Lookup a mapped value within the map, or return null if a mapping does not
71  /// exist. If a mapping exists, this follows the same behavior of
72  /// `lookupOrDefault`.
73  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
74 
75  /// Map a value to the one provided.
76  void map(Value oldVal, Value newVal) {
77  LLVM_DEBUG({
78  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
79  assert(it != oldVal && "inserting cyclic mapping");
80  });
81  mapping.map(oldVal, newVal);
82  }
83 
84  /// Try to map a value to the one provided. Returns false if a transitive
85  /// mapping from the new value to the old value already exists, true if the
86  /// map was updated.
87  bool tryMap(Value oldVal, Value newVal);
88 
89  /// Drop the last mapping for the given value.
90  void erase(Value value) { mapping.erase(value); }
91 
92  /// Returns the inverse raw value mapping (without recursive query support).
93  DenseMap<Value, SmallVector<Value>> getInverse() const {
95  for (auto &it : mapping.getValueMap())
96  inverse[it.second].push_back(it.first);
97  return inverse;
98  }
99 
100 private:
101  /// Current value mappings.
102  IRMapping mapping;
103 };
104 } // namespace
105 
106 Value ConversionValueMapping::lookupOrDefault(Value from,
107  Type desiredType) const {
108  // If there was no desired type, simply find the leaf value.
109  if (!desiredType) {
110  // If this value had a valid mapping, unmap that value as well in the case
111  // that it was also replaced.
112  while (auto mappedValue = mapping.lookupOrNull(from))
113  from = mappedValue;
114  return from;
115  }
116 
117  // Otherwise, try to find the deepest value that has the desired type.
118  Value desiredValue;
119  do {
120  if (from.getType() == desiredType)
121  desiredValue = from;
122 
123  Value mappedValue = mapping.lookupOrNull(from);
124  if (!mappedValue)
125  break;
126  from = mappedValue;
127  } while (true);
128 
129  // If the desired value was found use it, otherwise default to the leaf value.
130  return desiredValue ? desiredValue : from;
131 }
132 
133 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
134  Value result = lookupOrDefault(from, desiredType);
135  if (result == from || (desiredType && result.getType() != desiredType))
136  return nullptr;
137  return result;
138 }
139 
140 bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
141  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
142  if (it == oldVal)
143  return false;
144  map(oldVal, newVal);
145  return true;
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // Rewriter and Translation State
150 //===----------------------------------------------------------------------===//
151 namespace {
/// A snapshot of the conversion rewriter's counters, taken so that a set of
/// rewrites can later be undone by resetting to this point.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
        numReplacedOps(numReplacedOps) {}

  /// Number of rewrites performed at the time of the snapshot.
  unsigned numRewrites;

  /// Number of ignored operations at the time of the snapshot.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure at the time of the snapshot.
  unsigned numReplacedOps;
};
169 
170 //===----------------------------------------------------------------------===//
171 // IR rewrites
172 //===----------------------------------------------------------------------===//
173 
174 /// An IR rewrite that can be committed (upon success) or rolled back (upon
175 /// failure).
176 ///
177 /// The dialect conversion keeps track of IR modifications (requested by the
178 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
179 /// are directly applied to the IR as the rewriter API is used, some are applied
180 /// partially, and some are delayed until the `IRRewrite` objects are committed.
181 class IRRewrite {
182 public:
183  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
184  /// Enum values are ordered, so that they can be used in `classof`: first all
185  /// block rewrites, then all operation rewrites.
186  enum class Kind {
187  // Block rewrites
188  CreateBlock,
189  EraseBlock,
190  InlineBlock,
191  MoveBlock,
192  BlockTypeConversion,
193  ReplaceBlockArg,
194  // Operation rewrites
195  MoveOperation,
196  ModifyOperation,
197  ReplaceOperation,
198  CreateOperation,
199  UnresolvedMaterialization
200  };
201 
202  virtual ~IRRewrite() = default;
203 
204  /// Roll back the rewrite. Operations may be erased during rollback.
205  virtual void rollback() = 0;
206 
207  /// Commit the rewrite. At this point, it is certain that the dialect
208  /// conversion will succeed. All IR modifications, except for operation/block
209  /// erasure, must be performed through the given rewriter.
210  ///
211  /// Instead of erasing operations/blocks, they should merely be unlinked
212  /// commit phase and finally be erased during the cleanup phase. This is
213  /// because internal dialect conversion state (such as `mapping`) may still
214  /// be using them.
215  ///
216  /// Any IR modification that was already performed before the commit phase
217  /// (e.g., insertion of an op) must be communicated to the listener that may
218  /// be attached to the given rewriter.
219  virtual void commit(RewriterBase &rewriter) {}
220 
221  /// Cleanup operations/blocks. Cleanup is called after commit.
222  virtual void cleanup(RewriterBase &rewriter) {}
223 
224  Kind getKind() const { return kind; }
225 
226  static bool classof(const IRRewrite *rewrite) { return true; }
227 
228 protected:
229  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
230  : kind(kind), rewriterImpl(rewriterImpl) {}
231 
232  const ConversionConfig &getConfig() const;
233 
234  const Kind kind;
235  ConversionPatternRewriterImpl &rewriterImpl;
236 };
237 
238 /// A block rewrite.
239 class BlockRewrite : public IRRewrite {
240 public:
241  /// Return the block that this rewrite operates on.
242  Block *getBlock() const { return block; }
243 
244  static bool classof(const IRRewrite *rewrite) {
245  return rewrite->getKind() >= Kind::CreateBlock &&
246  rewrite->getKind() <= Kind::ReplaceBlockArg;
247  }
248 
249 protected:
250  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
251  Block *block)
252  : IRRewrite(kind, rewriterImpl), block(block) {}
253 
254  // The block that this rewrite operates on.
255  Block *block;
256 };
257 
258 /// Creation of a block. Block creations are immediately reflected in the IR.
259 /// There is no extra work to commit the rewrite. During rollback, the newly
260 /// created block is erased.
261 class CreateBlockRewrite : public BlockRewrite {
262 public:
263  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
264  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
265 
266  static bool classof(const IRRewrite *rewrite) {
267  return rewrite->getKind() == Kind::CreateBlock;
268  }
269 
270  void commit(RewriterBase &rewriter) override {
271  // The block was already created and inserted. Just inform the listener.
272  if (auto *listener = rewriter.getListener())
273  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
274  }
275 
276  void rollback() override {
277  // Unlink all of the operations within this block, they will be deleted
278  // separately.
279  auto &blockOps = block->getOperations();
280  while (!blockOps.empty())
281  blockOps.remove(blockOps.begin());
282  block->dropAllUses();
283  if (block->getParent())
284  block->erase();
285  else
286  delete block;
287  }
288 };
289 
290 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
291 /// blocks are immediately unlinked, but only erased during cleanup. This makes
292 /// it easier to rollback a block erasure: the block is simply inserted into its
293 /// original location.
294 class EraseBlockRewrite : public BlockRewrite {
295 public:
296  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
297  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
298  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
299 
300  static bool classof(const IRRewrite *rewrite) {
301  return rewrite->getKind() == Kind::EraseBlock;
302  }
303 
304  ~EraseBlockRewrite() override {
305  assert(!block &&
306  "rewrite was neither rolled back nor committed/cleaned up");
307  }
308 
309  void rollback() override {
310  // The block (owned by this rewrite) was not actually erased yet. It was
311  // just unlinked. Put it back into its original position.
312  assert(block && "expected block");
313  auto &blockList = region->getBlocks();
314  Region::iterator before = insertBeforeBlock
315  ? Region::iterator(insertBeforeBlock)
316  : blockList.end();
317  blockList.insert(before, block);
318  block = nullptr;
319  }
320 
321  void commit(RewriterBase &rewriter) override {
322  // Erase the block.
323  assert(block && "expected block");
324  assert(block->empty() && "expected empty block");
325 
326  // Notify the listener that the block is about to be erased.
327  if (auto *listener =
328  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
329  listener->notifyBlockErased(block);
330  }
331 
332  void cleanup(RewriterBase &rewriter) override {
333  // Erase the block.
334  block->dropAllDefinedValueUses();
335  delete block;
336  block = nullptr;
337  }
338 
339 private:
340  // The region in which this block was previously contained.
341  Region *region;
342 
343  // The original successor of this block before it was unlinked. "nullptr" if
344  // this block was the only block in the region.
345  Block *insertBeforeBlock;
346 };
347 
348 /// Inlining of a block. This rewrite is immediately reflected in the IR.
349 /// Note: This rewrite represents only the inlining of the operations. The
350 /// erasure of the inlined block is a separate rewrite.
351 class InlineBlockRewrite : public BlockRewrite {
352 public:
353  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
354  Block *sourceBlock, Block::iterator before)
355  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
356  sourceBlock(sourceBlock),
357  firstInlinedInst(sourceBlock->empty() ? nullptr
358  : &sourceBlock->front()),
359  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
360  // If a listener is attached to the dialect conversion, ops must be moved
361  // one-by-one. When they are moved in bulk, notifications cannot be sent
362  // because the ops that used to be in the source block at the time of the
363  // inlining (before the "commit" phase) are unknown at the time when
364  // notifications are sent (which is during the "commit" phase).
365  assert(!getConfig().listener &&
366  "InlineBlockRewrite not supported if listener is attached");
367  }
368 
369  static bool classof(const IRRewrite *rewrite) {
370  return rewrite->getKind() == Kind::InlineBlock;
371  }
372 
373  void rollback() override {
374  // Put the operations from the destination block (owned by the rewrite)
375  // back into the source block.
376  if (firstInlinedInst) {
377  assert(lastInlinedInst && "expected operation");
378  sourceBlock->getOperations().splice(sourceBlock->begin(),
379  block->getOperations(),
380  Block::iterator(firstInlinedInst),
381  ++Block::iterator(lastInlinedInst));
382  }
383  }
384 
385 private:
386  // The block that originally contained the operations.
387  Block *sourceBlock;
388 
389  // The first inlined operation.
390  Operation *firstInlinedInst;
391 
392  // The last inlined operation.
393  Operation *lastInlinedInst;
394 };
395 
396 /// Moving of a block. This rewrite is immediately reflected in the IR.
397 class MoveBlockRewrite : public BlockRewrite {
398 public:
399  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
400  Region *region, Block *insertBeforeBlock)
401  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
402  insertBeforeBlock(insertBeforeBlock) {}
403 
404  static bool classof(const IRRewrite *rewrite) {
405  return rewrite->getKind() == Kind::MoveBlock;
406  }
407 
408  void commit(RewriterBase &rewriter) override {
409  // The block was already moved. Just inform the listener.
410  if (auto *listener = rewriter.getListener()) {
411  // Note: `previousIt` cannot be passed because this is a delayed
412  // notification and iterators into past IR state cannot be represented.
413  listener->notifyBlockInserted(block, /*previous=*/region,
414  /*previousIt=*/{});
415  }
416  }
417 
418  void rollback() override {
419  // Move the block back to its original position.
420  Region::iterator before =
421  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
422  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
423  }
424 
425 private:
426  // The region in which this block was previously contained.
427  Region *region;
428 
429  // The original successor of this block before it was moved. "nullptr" if
430  // this block was the only block in the region.
431  Block *insertBeforeBlock;
432 };
433 
434 /// Block type conversion. This rewrite is partially reflected in the IR.
435 class BlockTypeConversionRewrite : public BlockRewrite {
436 public:
437  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
438  Block *block, Block *origBlock,
439  const TypeConverter *converter)
440  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
441  origBlock(origBlock), converter(converter) {}
442 
443  static bool classof(const IRRewrite *rewrite) {
444  return rewrite->getKind() == Kind::BlockTypeConversion;
445  }
446 
447  /// Materialize any necessary conversions for converted arguments that have
448  /// live users, using the provided `findLiveUser` to search for a user that
449  /// survives the conversion process.
450  LogicalResult
451  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
452 
453  void commit(RewriterBase &rewriter) override;
454 
455  void rollback() override;
456 
457 private:
458  /// The original block that was requested to have its signature converted.
459  Block *origBlock;
460 
461  /// The type converter used to convert the arguments.
462  const TypeConverter *converter;
463 };
464 
465 /// Replacing a block argument. This rewrite is not immediately reflected in the
466 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
467 /// until the rewrite is committed.
468 class ReplaceBlockArgRewrite : public BlockRewrite {
469 public:
470  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
471  Block *block, BlockArgument arg)
472  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
473 
474  static bool classof(const IRRewrite *rewrite) {
475  return rewrite->getKind() == Kind::ReplaceBlockArg;
476  }
477 
478  void commit(RewriterBase &rewriter) override;
479 
480  void rollback() override;
481 
482 private:
483  BlockArgument arg;
484 };
485 
486 /// An operation rewrite.
487 class OperationRewrite : public IRRewrite {
488 public:
489  /// Return the operation that this rewrite operates on.
490  Operation *getOperation() const { return op; }
491 
492  static bool classof(const IRRewrite *rewrite) {
493  return rewrite->getKind() >= Kind::MoveOperation &&
494  rewrite->getKind() <= Kind::UnresolvedMaterialization;
495  }
496 
497 protected:
498  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
499  Operation *op)
500  : IRRewrite(kind, rewriterImpl), op(op) {}
501 
502  // The operation that this rewrite operates on.
503  Operation *op;
504 };
505 
506 /// Moving of an operation. This rewrite is immediately reflected in the IR.
507 class MoveOperationRewrite : public OperationRewrite {
508 public:
509  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
510  Operation *op, Block *block, Operation *insertBeforeOp)
511  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
512  insertBeforeOp(insertBeforeOp) {}
513 
514  static bool classof(const IRRewrite *rewrite) {
515  return rewrite->getKind() == Kind::MoveOperation;
516  }
517 
518  void commit(RewriterBase &rewriter) override {
519  // The operation was already moved. Just inform the listener.
520  if (auto *listener = rewriter.getListener()) {
521  // Note: `previousIt` cannot be passed because this is a delayed
522  // notification and iterators into past IR state cannot be represented.
523  listener->notifyOperationInserted(
524  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
525  /*insertPt=*/{}));
526  }
527  }
528 
529  void rollback() override {
530  // Move the operation back to its original position.
531  Block::iterator before =
532  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
533  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
534  }
535 
536 private:
537  // The block in which this operation was previously contained.
538  Block *block;
539 
540  // The original successor of this operation before it was moved. "nullptr"
541  // if this operation was the only operation in the region.
542  Operation *insertBeforeOp;
543 };
544 
545 /// In-place modification of an op. This rewrite is immediately reflected in
546 /// the IR. The previous state of the operation is stored in this object.
547 class ModifyOperationRewrite : public OperationRewrite {
548 public:
549  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
550  Operation *op)
551  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
552  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
553  operands(op->operand_begin(), op->operand_end()),
554  successors(op->successor_begin(), op->successor_end()) {
555  if (OpaqueProperties prop = op->getPropertiesStorage()) {
556  // Make a copy of the properties.
557  propertiesStorage = operator new(op->getPropertiesStorageSize());
558  OpaqueProperties propCopy(propertiesStorage);
559  name.initOpProperties(propCopy, /*init=*/prop);
560  }
561  }
562 
563  static bool classof(const IRRewrite *rewrite) {
564  return rewrite->getKind() == Kind::ModifyOperation;
565  }
566 
567  ~ModifyOperationRewrite() override {
568  assert(!propertiesStorage &&
569  "rewrite was neither committed nor rolled back");
570  }
571 
572  void commit(RewriterBase &rewriter) override {
573  // Notify the listener that the operation was modified in-place.
574  if (auto *listener =
575  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
576  listener->notifyOperationModified(op);
577 
578  if (propertiesStorage) {
579  OpaqueProperties propCopy(propertiesStorage);
580  // Note: The operation may have been erased in the mean time, so
581  // OperationName must be stored in this object.
582  name.destroyOpProperties(propCopy);
583  operator delete(propertiesStorage);
584  propertiesStorage = nullptr;
585  }
586  }
587 
588  void rollback() override {
589  op->setLoc(loc);
590  op->setAttrs(attrs);
591  op->setOperands(operands);
592  for (const auto &it : llvm::enumerate(successors))
593  op->setSuccessor(it.value(), it.index());
594  if (propertiesStorage) {
595  OpaqueProperties propCopy(propertiesStorage);
596  op->copyProperties(propCopy);
597  name.destroyOpProperties(propCopy);
598  operator delete(propertiesStorage);
599  propertiesStorage = nullptr;
600  }
601  }
602 
603 private:
604  OperationName name;
605  LocationAttr loc;
606  DictionaryAttr attrs;
607  SmallVector<Value, 8> operands;
608  SmallVector<Block *, 2> successors;
609  void *propertiesStorage = nullptr;
610 };
611 
612 /// Replacing an operation. Erasing an operation is treated as a special case
613 /// with "null" replacements. This rewrite is not immediately reflected in the
614 /// IR. An internal IR mapping is updated, but values are not replaced and the
615 /// original op is not erased until the rewrite is committed.
616 class ReplaceOperationRewrite : public OperationRewrite {
617 public:
618  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
619  Operation *op, const TypeConverter *converter,
620  bool changedResults)
621  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
622  converter(converter), changedResults(changedResults) {}
623 
624  static bool classof(const IRRewrite *rewrite) {
625  return rewrite->getKind() == Kind::ReplaceOperation;
626  }
627 
628  void commit(RewriterBase &rewriter) override;
629 
630  void rollback() override;
631 
632  void cleanup(RewriterBase &rewriter) override;
633 
634  const TypeConverter *getConverter() const { return converter; }
635 
636  bool hasChangedResults() const { return changedResults; }
637 
638 private:
639  /// An optional type converter that can be used to materialize conversions
640  /// between the new and old values if necessary.
641  const TypeConverter *converter;
642 
643  /// A boolean flag that indicates whether result types have changed or not.
644  bool changedResults;
645 };
646 
647 class CreateOperationRewrite : public OperationRewrite {
648 public:
649  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
650  Operation *op)
651  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
652 
653  static bool classof(const IRRewrite *rewrite) {
654  return rewrite->getKind() == Kind::CreateOperation;
655  }
656 
657  void commit(RewriterBase &rewriter) override {
658  // The operation was already created and inserted. Just inform the listener.
659  if (auto *listener = rewriter.getListener())
660  listener->notifyOperationInserted(op, /*previous=*/{});
661  }
662 
663  void rollback() override;
664 };
665 
/// The kind of a materialization.
enum MaterializationKind {
  /// Materializes a conversion for an illegal block argument type, to the
  /// original one.
  Argument,

  /// Materializes a conversion from an illegal type to a legal one.
  Target,

  /// Materializes a conversion from a legal type back to an illegal one.
  Source
};
680 
681 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
682 /// op. Unresolved materializations are erased at the end of the dialect
683 /// conversion.
684 class UnresolvedMaterializationRewrite : public OperationRewrite {
685 public:
686  UnresolvedMaterializationRewrite(
687  ConversionPatternRewriterImpl &rewriterImpl,
688  UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
689  MaterializationKind kind = MaterializationKind::Target)
690  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
691  converterAndKind(converter, kind) {}
692 
693  static bool classof(const IRRewrite *rewrite) {
694  return rewrite->getKind() == Kind::UnresolvedMaterialization;
695  }
696 
697  UnrealizedConversionCastOp getOperation() const {
698  return cast<UnrealizedConversionCastOp>(op);
699  }
700 
701  void rollback() override;
702 
703  void cleanup(RewriterBase &rewriter) override;
704 
705  /// Return the type converter of this materialization (which may be null).
706  const TypeConverter *getConverter() const {
707  return converterAndKind.getPointer();
708  }
709 
710  /// Return the kind of this materialization.
711  MaterializationKind getMaterializationKind() const {
712  return converterAndKind.getInt();
713  }
714 
715 private:
716  /// The corresponding type converter to use when resolving this
717  /// materialization, and the kind of this materialization.
718  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
719  converterAndKind;
720 };
721 } // namespace
722 
723 /// Return "true" if there is an operation rewrite that matches the specified
724 /// rewrite type and operation among the given rewrites.
725 template <typename RewriteTy, typename R>
726 static bool hasRewrite(R &&rewrites, Operation *op) {
727  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
728  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
729  return rewriteTy && rewriteTy->getOperation() == op;
730  });
731 }
732 
733 /// Find the single rewrite object of the specified type and block among the
734 /// given rewrites. In debug mode, asserts that there is mo more than one such
735 /// object. Return "nullptr" if no object was found.
736 template <typename RewriteTy, typename R>
737 static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
738  RewriteTy *result = nullptr;
739  for (auto &rewrite : rewrites) {
740  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
741  if (rewriteTy && rewriteTy->getBlock() == block) {
742 #ifndef NDEBUG
743  assert(!result && "expected single matching rewrite");
744  result = rewriteTy;
745 #else
746  return rewriteTy;
747 #endif // NDEBUG
748  }
749  }
750  return result;
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // ConversionPatternRewriterImpl
755 //===----------------------------------------------------------------------===//
756 namespace mlir {
757 namespace detail {
760  const ConversionConfig &config)
761  : context(ctx), config(config) {}
762 
763  //===--------------------------------------------------------------------===//
764  // State Management
765  //===--------------------------------------------------------------------===//
766 
767  /// Return the current state of the rewriter.
768  RewriterState getCurrentState();
769 
770  /// Apply all requested operation rewrites. This method is invoked when the
771  /// conversion process succeeds.
772  void applyRewrites();
773 
774  /// Reset the state of the rewriter to a previously saved point.
775  void resetState(RewriterState state);
776 
777  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
778  /// failure.
779  template <typename RewriteTy, typename... Args>
780  void appendRewrite(Args &&...args) {
781  rewrites.push_back(
782  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
783  }
784 
785  /// Undo the rewrites (motions, splits) one by one in reverse order until
786  /// "numRewritesToKeep" rewrites remains.
787  void undoRewrites(unsigned numRewritesToKeep = 0);
788 
789  /// Remap the given values to those with potentially different types. Returns
790  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
791  /// is the tag used when describing a value within a diagnostic, e.g.
792  /// "operand".
793  LogicalResult remapValues(StringRef valueDiagTag,
794  std::optional<Location> inputLoc,
795  PatternRewriter &rewriter, ValueRange values,
796  SmallVectorImpl<Value> &remapped);
797 
798  /// Return "true" if the given operation is ignored, and does not need to be
799  /// converted.
800  bool isOpIgnored(Operation *op) const;
801 
802  /// Return "true" if the given operation was replaced or erased.
803  bool wasOpReplaced(Operation *op) const;
804 
805  //===--------------------------------------------------------------------===//
806  // Type Conversion
807  //===--------------------------------------------------------------------===//
808 
809  /// Convert the types of block arguments within the given region.
810  FailureOr<Block *>
811  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
812  const TypeConverter &converter,
813  TypeConverter::SignatureConversion *entryConversion);
814 
815  /// Apply the given signature conversion on the given block. The new block
816  /// containing the updated signature is returned. If no conversions were
817  /// necessary, e.g. if the block has no arguments, `block` is returned.
818  /// `converter` is used to generate any necessary cast operations that
819  /// translate between the origin argument types and those specified in the
820  /// signature conversion.
821  Block *applySignatureConversion(
822  ConversionPatternRewriter &rewriter, Block *block,
823  const TypeConverter *converter,
824  TypeConverter::SignatureConversion &signatureConversion);
825 
826  //===--------------------------------------------------------------------===//
827  // Materializations
828  //===--------------------------------------------------------------------===//
829  /// Build an unresolved materialization operation given an output type and set
830  /// of input operands.
831  Value buildUnresolvedMaterialization(MaterializationKind kind,
832  Block *insertBlock,
833  Block::iterator insertPt, Location loc,
834  ValueRange inputs, Type outputType,
835  const TypeConverter *converter);
836 
837  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
838  Type outputType,
839  const TypeConverter *converter);
840 
841  //===--------------------------------------------------------------------===//
842  // Rewriter Notification Hooks
843  //===--------------------------------------------------------------------===//
844 
845  /// Notifies that an op was inserted.
846  void notifyOperationInserted(Operation *op,
847  OpBuilder::InsertPoint previous) override;
848 
849  /// Notifies that an op is about to be replaced with the given values.
850  void notifyOpReplaced(Operation *op, ValueRange newValues);
851 
852  /// Notifies that a block is about to be erased.
853  void notifyBlockIsBeingErased(Block *block);
854 
855  /// Notifies that a block was inserted.
856  void notifyBlockInserted(Block *block, Region *previous,
857  Region::iterator previousIt) override;
858 
859  /// Notifies that a block is being inlined into another block.
860  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
861  Block::iterator before);
862 
863  /// Notifies that a pattern match failed for the given reason.
864  void
865  notifyMatchFailure(Location loc,
866  function_ref<void(Diagnostic &)> reasonCallback) override;
867 
868  //===--------------------------------------------------------------------===//
869  // IR Erasure
870  //===--------------------------------------------------------------------===//
871 
872  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
873  /// operation or block is erased multiple times. This rewriter assumes that
874  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
876  public:
878  : RewriterBase(context, /*listener=*/this) {}
879 
880  /// Erase the given op (unless it was already erased).
881  void eraseOp(Operation *op) override {
882  if (erased.contains(op))
883  return;
884  op->dropAllUses();
886  }
887 
888  /// Erase the given block (unless it was already erased).
889  void eraseBlock(Block *block) override {
890  if (erased.contains(block))
891  return;
892  assert(block->empty() && "expected empty block");
893  block->dropAllDefinedValueUses();
895  }
896 
897  void notifyOperationErased(Operation *op) override { erased.insert(op); }
898 
899  void notifyBlockErased(Block *block) override { erased.insert(block); }
900 
901  /// Pointers to all erased operations and blocks.
903  };
904 
905  //===--------------------------------------------------------------------===//
906  // State
907  //===--------------------------------------------------------------------===//
908 
909  /// MLIR context.
911 
912  // Mapping between replaced values that differ in type. This happens when
913  // replacing a value with one of a different type.
914  ConversionValueMapping mapping;
915 
916  /// Ordered list of block operations (creations, splits, motions).
918 
919  /// A set of operations that should no longer be considered for legalization.
920  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
921  /// tracked separately.
923 
924  /// A set of operations that were replaced/erased. Such ops are not erased
925  /// immediately but only when the dialect conversion succeeds. In the mean
926  /// time, they should no longer be considered for legalization and any attempt
927  /// to modify/access them is invalid rewriter API usage.
929 
930  /// The current type converter, or nullptr if no type converter is currently
931  /// active.
932  const TypeConverter *currentTypeConverter = nullptr;
933 
934  /// A mapping of regions to type converters that should be used when
935  /// converting the arguments of blocks within that region.
937 
938  /// Dialect conversion configuration.
940 
941 #ifndef NDEBUG
942  /// A set of operations that have pending updates. This tracking isn't
943  /// strictly necessary, and is thus only active during debug builds for extra
944  /// verification.
946 
947  /// A logger used to emit diagnostics during the conversion process.
948  llvm::ScopedPrinter logger{llvm::dbgs()};
949 #endif
950 };
951 } // namespace detail
952 } // namespace mlir
953 
954 const ConversionConfig &IRRewrite::getConfig() const {
955  return rewriterImpl.config;
956 }
957 
958 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
959  // Inform the listener about all IR modifications that have already taken
960  // place: References to the original block have been replaced with the new
961  // block.
962  if (auto *listener =
963  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
964  for (Operation *op : block->getUsers())
965  listener->notifyOperationModified(op);
966 }
967 
968 void BlockTypeConversionRewrite::rollback() {
969  block->replaceAllUsesWith(origBlock);
970 }
971 
972 LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
973  function_ref<Operation *(Value)> findLiveUser) {
974  // Process the remapping for each of the original arguments.
975  for (auto it : llvm::enumerate(origBlock->getArguments())) {
976  BlockArgument origArg = it.value();
977  // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
978  OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
979  builder.setInsertionPointToStart(block);
980 
981  // If the type of this argument changed and the argument is still live, we
982  // need to materialize a conversion.
983  if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
984  continue;
985  Operation *liveUser = findLiveUser(origArg);
986  if (!liveUser)
987  continue;
988 
989  Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
990  assert(replacementValue && "replacement value not found");
991  Value newArg;
992  if (converter) {
993  builder.setInsertionPointAfterValue(replacementValue);
994  newArg = converter->materializeSourceConversion(
995  builder, origArg.getLoc(), origArg.getType(), replacementValue);
996  assert((!newArg || newArg.getType() == origArg.getType()) &&
997  "materialization hook did not provide a value of the expected "
998  "type");
999  }
1000  if (!newArg) {
1002  emitError(origArg.getLoc())
1003  << "failed to materialize conversion for block argument #"
1004  << it.index() << " that remained live after conversion, type was "
1005  << origArg.getType();
1006  diag.attachNote(liveUser->getLoc())
1007  << "see existing live user here: " << *liveUser;
1008  return failure();
1009  }
1010  rewriterImpl.mapping.map(origArg, newArg);
1011  }
1012  return success();
1013 }
1014 
1015 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1016  Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
1017  if (!repl)
1018  return;
1019 
1020  if (isa<BlockArgument>(repl)) {
1021  rewriter.replaceAllUsesWith(arg, repl);
1022  return;
1023  }
1024 
1025  // If the replacement value is an operation, we check to make sure that we
1026  // don't replace uses that are within the parent operation of the
1027  // replacement value.
1028  Operation *replOp = cast<OpResult>(repl).getOwner();
1029  Block *replBlock = replOp->getBlock();
1030  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1031  Operation *user = operand.getOwner();
1032  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1033  });
1034 }
1035 
1036 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
1037 
1038 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1039  auto *listener =
1040  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1041 
1042  // Compute replacement values.
1043  SmallVector<Value> replacements =
1044  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1045  return rewriterImpl.mapping.lookupOrNull(result, result.getType());
1046  });
1047 
1048  // Notify the listener that the operation is about to be replaced.
1049  if (listener)
1050  listener->notifyOperationReplaced(op, replacements);
1051 
1052  // Replace all uses with the new values.
1053  for (auto [result, newValue] :
1054  llvm::zip_equal(op->getResults(), replacements))
1055  if (newValue)
1056  rewriter.replaceAllUsesWith(result, newValue);
1057 
1058  // The original op will be erased, so remove it from the set of unlegalized
1059  // ops.
1060  if (getConfig().unlegalizedOps)
1061  getConfig().unlegalizedOps->erase(op);
1062 
1063  // Notify the listener that the operation (and its nested operations) was
1064  // erased.
1065  if (listener) {
1067  [&](Operation *op) { listener->notifyOperationErased(op); });
1068  }
1069 
1070  // Do not erase the operation yet. It may still be referenced in `mapping`.
1071  // Just unlink it for now and erase it during cleanup.
1072  op->getBlock()->getOperations().remove(op);
1073 }
1074 
1075 void ReplaceOperationRewrite::rollback() {
1076  for (auto result : op->getResults())
1077  rewriterImpl.mapping.erase(result);
1078 }
1079 
1080 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1081  rewriter.eraseOp(op);
1082 }
1083 
1084 void CreateOperationRewrite::rollback() {
1085  for (Region &region : op->getRegions()) {
1086  while (!region.getBlocks().empty())
1087  region.getBlocks().remove(region.getBlocks().begin());
1088  }
1089  op->dropAllUses();
1090  op->erase();
1091 }
1092 
1093 void UnresolvedMaterializationRewrite::rollback() {
1094  if (getMaterializationKind() == MaterializationKind::Target) {
1095  for (Value input : op->getOperands())
1096  rewriterImpl.mapping.erase(input);
1097  }
1098  op->erase();
1099 }
1100 
1101 void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
1102  rewriter.eraseOp(op);
1103 }
1104 
1106  // Commit all rewrites.
1107  IRRewriter rewriter(context, config.listener);
1108  for (auto &rewrite : rewrites)
1109  rewrite->commit(rewriter);
1110 
1111  // Clean up all rewrites.
1112  SingleEraseRewriter eraseRewriter(context);
1113  for (auto &rewrite : rewrites)
1114  rewrite->cleanup(eraseRewriter);
1115 }
1116 
1117 //===----------------------------------------------------------------------===//
1118 // State Management
1119 
1121  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1122 }
1123 
1125  // Undo any rewrites.
1126  undoRewrites(state.numRewrites);
1127 
1128  // Pop all of the recorded ignored operations that are no longer valid.
1129  while (ignoredOps.size() != state.numIgnoredOperations)
1130  ignoredOps.pop_back();
1131 
1132  while (replacedOps.size() != state.numReplacedOps)
1133  replacedOps.pop_back();
1134 }
1135 
1136 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1137  for (auto &rewrite :
1138  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1139  rewrite->rollback();
1140  rewrites.resize(numRewritesToKeep);
1141 }
1142 
1144  StringRef valueDiagTag, std::optional<Location> inputLoc,
1145  PatternRewriter &rewriter, ValueRange values,
1146  SmallVectorImpl<Value> &remapped) {
1147  remapped.reserve(llvm::size(values));
1148 
1149  SmallVector<Type, 1> legalTypes;
1150  for (const auto &it : llvm::enumerate(values)) {
1151  Value operand = it.value();
1152  Type origType = operand.getType();
1153 
1154  // If a converter was provided, get the desired legal types for this
1155  // operand.
1156  Type desiredType;
1157  if (currentTypeConverter) {
1158  // If there is no legal conversion, fail to match this pattern.
1159  legalTypes.clear();
1160  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1161  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1162  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1163  diag << "unable to convert type for " << valueDiagTag << " #"
1164  << it.index() << ", type was " << origType;
1165  });
1166  return failure();
1167  }
1168  // TODO: There currently isn't any mechanism to do 1->N type conversion
1169  // via the PatternRewriter replacement API, so for now we just ignore it.
1170  if (legalTypes.size() == 1)
1171  desiredType = legalTypes.front();
1172  } else {
1173  // TODO: What we should do here is just set `desiredType` to `origType`
1174  // and then handle the necessary type conversions after the conversion
1175  // process has finished. Unfortunately a lot of patterns currently rely on
1176  // receiving the new operands even if the types change, so we keep the
1177  // original behavior here for now until all of the patterns relying on
1178  // this get updated.
1179  }
1180  Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1181 
1182  // Handle the case where the conversion was 1->1 and the new operand type
1183  // isn't legal.
1184  Type newOperandType = newOperand.getType();
1185  if (currentTypeConverter && desiredType && newOperandType != desiredType) {
1186  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1188  operandLoc, newOperand, desiredType, currentTypeConverter);
1189  mapping.map(mapping.lookupOrDefault(newOperand), castValue);
1190  newOperand = castValue;
1191  }
1192  remapped.push_back(newOperand);
1193  }
1194  return success();
1195 }
1196 
1198  // Check to see if this operation is ignored or was replaced.
1199  return replacedOps.count(op) || ignoredOps.count(op);
1200 }
1201 
1203  // Check to see if this operation was replaced.
1204  return replacedOps.count(op);
1205 }
1206 
1207 //===----------------------------------------------------------------------===//
1208 // Type Conversion
1209 
1211  ConversionPatternRewriter &rewriter, Region *region,
1212  const TypeConverter &converter,
1213  TypeConverter::SignatureConversion *entryConversion) {
1214  regionToConverter[region] = &converter;
1215  if (region->empty())
1216  return nullptr;
1217 
1218  // Convert the arguments of each non-entry block within the region.
1219  for (Block &block :
1220  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1221  // Compute the signature for the block with the provided converter.
1222  std::optional<TypeConverter::SignatureConversion> conversion =
1223  converter.convertBlockSignature(&block);
1224  if (!conversion)
1225  return failure();
1226  // Convert the block with the computed signature.
1227  applySignatureConversion(rewriter, &block, &converter, *conversion);
1228  }
1229 
1230  // Convert the entry block. If an entry signature conversion was provided,
1231  // use that one. Otherwise, compute the signature with the type converter.
1232  if (entryConversion)
1233  return applySignatureConversion(rewriter, &region->front(), &converter,
1234  *entryConversion);
1235  std::optional<TypeConverter::SignatureConversion> conversion =
1236  converter.convertBlockSignature(&region->front());
1237  if (!conversion)
1238  return failure();
1239  return applySignatureConversion(rewriter, &region->front(), &converter,
1240  *conversion);
1241 }
1242 
1244  ConversionPatternRewriter &rewriter, Block *block,
1245  const TypeConverter *converter,
1246  TypeConverter::SignatureConversion &signatureConversion) {
1247  OpBuilder::InsertionGuard g(rewriter);
1248 
1249  // If no arguments are being changed or added, there is nothing to do.
1250  unsigned origArgCount = block->getNumArguments();
1251  auto convertedTypes = signatureConversion.getConvertedTypes();
1252  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1253  return block;
1254 
1255  // Compute the locations of all block arguments in the new block.
1256  SmallVector<Location> newLocs(convertedTypes.size(),
1257  rewriter.getUnknownLoc());
1258  for (unsigned i = 0; i < origArgCount; ++i) {
1259  auto inputMap = signatureConversion.getInputMapping(i);
1260  if (!inputMap || inputMap->replacementValue)
1261  continue;
1262  Location origLoc = block->getArgument(i).getLoc();
1263  for (unsigned j = 0; j < inputMap->size; ++j)
1264  newLocs[inputMap->inputNo + j] = origLoc;
1265  }
1266 
1267  // Insert a new block with the converted block argument types and move all ops
1268  // from the old block to the new block.
1269  Block *newBlock =
1270  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1271  convertedTypes, newLocs);
1272 
1273  // If a listener is attached to the dialect conversion, ops cannot be moved
1274  // to the destination block in bulk ("fast path"). This is because at the time
1275  // the notifications are sent, it is unknown which ops were moved. Instead,
1276  // ops should be moved one-by-one ("slow path"), so that a separate
1277  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1278  // a bit more efficient, so we try to do that when possible.
1279  bool fastPath = !config.listener;
1280  if (fastPath) {
1281  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1282  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1283  } else {
1284  while (!block->empty())
1285  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1286  }
1287 
1288  // Replace all uses of the old block with the new block.
1289  block->replaceAllUsesWith(newBlock);
1290 
1291  for (unsigned i = 0; i != origArgCount; ++i) {
1292  BlockArgument origArg = block->getArgument(i);
1293  Type origArgType = origArg.getType();
1294 
1295  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1296  signatureConversion.getInputMapping(i);
1297  if (!inputMap) {
1298  // This block argument was dropped and no replacement value was provided.
1299  // Materialize a replacement value "out of thin air".
1301  MaterializationKind::Source, newBlock, newBlock->begin(),
1302  origArg.getLoc(), /*inputs=*/ValueRange(),
1303  /*outputType=*/origArgType, converter);
1304  mapping.map(origArg, repl);
1305  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1306  continue;
1307  }
1308 
1309  if (Value repl = inputMap->replacementValue) {
1310  // This block argument was dropped and a replacement value was provided.
1311  assert(inputMap->size == 0 &&
1312  "invalid to provide a replacement value when the argument isn't "
1313  "dropped");
1314  mapping.map(origArg, repl);
1315  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1316  continue;
1317  }
1318 
1319  // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
1320  // dialect conversion. Therefore, we need an argument materialization to
1321  // turn the replacement block arguments into a single SSA value that can be
1322  // used as a replacement.
1323  auto replArgs =
1324  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1326  MaterializationKind::Argument, newBlock, newBlock->begin(),
1327  origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
1328  mapping.map(origArg, argMat);
1329  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1330 
1331  // FIXME: We simply pass through the replacement argument if there wasn't a
1332  // converter, which isn't great as it allows implicit type conversions to
1333  // appear. We should properly restructure this code to handle cases where a
1334  // converter isn't provided and also to properly handle the case where an
1335  // argument materialization is actually a temporary source materialization
1336  // (e.g. in the case of 1->N).
1337  Type legalOutputType;
1338  if (converter)
1339  legalOutputType = converter->convertType(origArgType);
1340  if (legalOutputType && legalOutputType != origArgType) {
1342  origArg.getLoc(), argMat, legalOutputType, converter);
1343  mapping.map(argMat, targetMat);
1344  }
1345  appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1346  }
1347 
1348  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
1349 
1350  // Erase the old block. (It is just unlinked for now and will be erased during
1351  // cleanup.)
1352  rewriter.eraseBlock(block);
1353 
1354  return newBlock;
1355 }
1356 
1357 //===----------------------------------------------------------------------===//
1358 // Materializations
1359 //===----------------------------------------------------------------------===//
1360 
1361 /// Build an unresolved materialization operation given an output type and set
1362 /// of input operands.
1364  MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1365  Location loc, ValueRange inputs, Type outputType,
1366  const TypeConverter *converter) {
1367  // Avoid materializing an unnecessary cast.
1368  if (inputs.size() == 1 && inputs.front().getType() == outputType)
1369  return inputs.front();
1370 
1371  // Create an unresolved materialization. We use a new OpBuilder to avoid
1372  // tracking the materialization like we do for other operations.
1373  OpBuilder builder(outputType.getContext());
1374  builder.setInsertionPoint(insertBlock, insertPt);
1375  auto convertOp =
1376  builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1377  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1378  return convertOp.getResult(0);
1379 }
1381  Location loc, Value input, Type outputType,
1382  const TypeConverter *converter) {
1383  Block *insertBlock = input.getParentBlock();
1384  Block::iterator insertPt = insertBlock->begin();
1385  if (OpResult inputRes = dyn_cast<OpResult>(input))
1386  insertPt = ++inputRes.getOwner()->getIterator();
1387 
1388  return buildUnresolvedMaterialization(MaterializationKind::Target,
1389  insertBlock, insertPt, loc, input,
1390  outputType, converter);
1391 }
1392 
1393 //===----------------------------------------------------------------------===//
1394 // Rewriter Notification Hooks
1395 
1397  Operation *op, OpBuilder::InsertPoint previous) {
1398  LLVM_DEBUG({
1399  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1400  << ")\n";
1401  });
1402  assert(!wasOpReplaced(op->getParentOp()) &&
1403  "attempting to insert into a block within a replaced/erased op");
1404 
1405  if (!previous.isSet()) {
1406  // This is a newly created op.
1407  appendRewrite<CreateOperationRewrite>(op);
1408  return;
1409  }
1410  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1411  ? nullptr
1412  : &*previous.getPoint();
1413  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1414 }
1415 
1417  ValueRange newValues) {
1418  assert(newValues.size() == op->getNumResults());
1419  assert(!ignoredOps.contains(op) && "operation was already replaced");
1420 
1421  // Track if any of the results changed, e.g. erased and replaced with null.
1422  bool resultChanged = false;
1423 
1424  // Create mappings for each of the new result values.
1425  for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
1426  if (!newValue) {
1427  resultChanged = true;
1428  continue;
1429  }
1430  // Remap, and check for any result type changes.
1431  mapping.map(result, newValue);
1432  resultChanged |= (newValue.getType() != result.getType());
1433  }
1434 
1435  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
1436  resultChanged);
1437 
1438  // Mark this operation and all nested ops as replaced.
1439  op->walk([&](Operation *op) { replacedOps.insert(op); });
1440 }
1441 
1443  appendRewrite<EraseBlockRewrite>(block);
1444 }
1445 
1447  Block *block, Region *previous, Region::iterator previousIt) {
1448  assert(!wasOpReplaced(block->getParentOp()) &&
1449  "attempting to insert into a region within a replaced/erased op");
1450  LLVM_DEBUG(
1451  {
1452  Operation *parent = block->getParentOp();
1453  if (parent) {
1454  logger.startLine() << "** Insert Block into : '" << parent->getName()
1455  << "'(" << parent << ")\n";
1456  } else {
1457  logger.startLine()
1458  << "** Insert Block into detached Region (nullptr parent op)'";
1459  }
1460  });
1461 
1462  if (!previous) {
1463  // This is a newly created block.
1464  appendRewrite<CreateBlockRewrite>(block);
1465  return;
1466  }
1467  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1468  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1469 }
1470 
1472  Block *block, Block *srcBlock, Block::iterator before) {
1473  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1474 }
1475 
1477  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1478  LLVM_DEBUG({
1480  reasonCallback(diag);
1481  logger.startLine() << "** Failure : " << diag.str() << "\n";
1482  if (config.notifyCallback)
1484  });
1485 }
1486 
1487 //===----------------------------------------------------------------------===//
1488 // ConversionPatternRewriter
1489 //===----------------------------------------------------------------------===//
1490 
1491 ConversionPatternRewriter::ConversionPatternRewriter(
1492  MLIRContext *ctx, const ConversionConfig &config)
1493  : PatternRewriter(ctx),
1494  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1495  setListener(impl.get());
1496 }
1497 
1499 
1501  assert(op && newOp && "expected non-null op");
1502  replaceOp(op, newOp->getResults());
1503 }
1504 
1506  assert(op->getNumResults() == newValues.size() &&
1507  "incorrect # of replacement values");
1508  LLVM_DEBUG({
1509  impl->logger.startLine()
1510  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1511  });
1512  impl->notifyOpReplaced(op, newValues);
1513 }
1514 
1516  LLVM_DEBUG({
1517  impl->logger.startLine()
1518  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1519  });
1520  SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1521  impl->notifyOpReplaced(op, nullRepls);
1522 }
1523 
1525  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1526  "attempting to erase a block within a replaced/erased op");
1527 
1528  // Mark all ops for erasure.
1529  for (Operation &op : *block)
1530  eraseOp(&op);
1531 
1532  // Unlink the block from its parent region. The block is kept in the rewrite
1533  // object and will be actually destroyed when rewrites are applied. This
1534  // allows us to keep the operations in the block live and undo the removal by
1535  // re-inserting the block.
1536  impl->notifyBlockIsBeingErased(block);
1537  block->getParent()->getBlocks().remove(block);
1538 }
1539 
1541  Block *block, TypeConverter::SignatureConversion &conversion,
1542  const TypeConverter *converter) {
1543  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1544  "attempting to apply a signature conversion to a block within a "
1545  "replaced/erased op");
1546  return impl->applySignatureConversion(*this, block, converter, conversion);
1547 }
1548 
1550  Region *region, const TypeConverter &converter,
1551  TypeConverter::SignatureConversion *entryConversion) {
1552  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1553  "attempting to apply a signature conversion to a block within a "
1554  "replaced/erased op");
1555  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1556 }
1557 
1559  Value to) {
1560  LLVM_DEBUG({
1561  Operation *parentOp = from.getOwner()->getParentOp();
1562  impl->logger.startLine() << "** Replace Argument : '" << from
1563  << "'(in region of '" << parentOp->getName()
1564  << "'(" << from.getOwner()->getParentOp() << ")\n";
1565  });
1566  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
1567  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1568 }
1569 
1571  SmallVector<Value> remappedValues;
1572  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1573  remappedValues)))
1574  return nullptr;
1575  return remappedValues.front();
1576 }
1577 
1578 LogicalResult
1580  SmallVectorImpl<Value> &results) {
1581  if (keys.empty())
1582  return success();
1583  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1584  results);
1585 }
1586 
1588  Block::iterator before,
1589  ValueRange argValues) {
1590 #ifndef NDEBUG
1591  assert(argValues.size() == source->getNumArguments() &&
1592  "incorrect # of argument replacement values");
1593  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1594  "attempting to inline a block from a replaced/erased op");
1595  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1596  "attempting to inline a block into a replaced/erased op");
1597  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1598  // The source block will be deleted, so it should not have any users (i.e.,
1599  // there should be no predecessors).
1600  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1601  "expected 'source' to have no predecessors");
1602 #endif // NDEBUG
1603 
1604  // If a listener is attached to the dialect conversion, ops cannot be moved
1605  // to the destination block in bulk ("fast path"). This is because at the time
1606  // the notifications are sent, it is unknown which ops were moved. Instead,
1607  // ops should be moved one-by-one ("slow path"), so that a separate
1608  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1609  // a bit more efficient, so we try to do that when possible.
1610  bool fastPath = !impl->config.listener;
1611 
1612  if (fastPath)
1613  impl->notifyBlockBeingInlined(dest, source, before);
1614 
1615  // Replace all uses of block arguments.
1616  for (auto it : llvm::zip(source->getArguments(), argValues))
1617  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1618 
1619  if (fastPath) {
1620  // Move all ops at once.
1621  dest->getOperations().splice(before, source->getOperations());
1622  } else {
1623  // Move op by op.
1624  while (!source->empty())
1625  moveOpBefore(&source->front(), dest, before);
1626  }
1627 
1628  // Erase the source block.
1629  eraseBlock(source);
1630 }
1631 
1633  assert(!impl->wasOpReplaced(op) &&
1634  "attempting to modify a replaced/erased op");
1635 #ifndef NDEBUG
1636  impl->pendingRootUpdates.insert(op);
1637 #endif
1638  impl->appendRewrite<ModifyOperationRewrite>(op);
1639 }
1640 
1642  assert(!impl->wasOpReplaced(op) &&
1643  "attempting to modify a replaced/erased op");
1645  // There is nothing to do here, we only need to track the operation at the
1646  // start of the update.
1647 #ifndef NDEBUG
1648  assert(impl->pendingRootUpdates.erase(op) &&
1649  "operation did not have a pending in-place update");
1650 #endif
1651 }
1652 
1654 #ifndef NDEBUG
1655  assert(impl->pendingRootUpdates.erase(op) &&
1656  "operation did not have a pending in-place update");
1657 #endif
1658  // Erase the last update for this operation.
1659  auto it = llvm::find_if(
1660  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1661  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1662  return modifyRewrite && modifyRewrite->getOperation() == op;
1663  });
1664  assert(it != impl->rewrites.rend() && "no root update started on op");
1665  (*it)->rollback();
1666  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1667  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1668 }
1669 
1671  return *impl;
1672 }
1673 
1674 //===----------------------------------------------------------------------===//
1675 // ConversionPattern
1676 //===----------------------------------------------------------------------===//
1677 
1678 LogicalResult
1680  PatternRewriter &rewriter) const {
1681  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1682  auto &rewriterImpl = dialectRewriter.getImpl();
1683 
1684  // Track the current conversion pattern type converter in the rewriter.
1685  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1686  getTypeConverter());
1687 
1688  // Remap the operands of the operation.
1689  SmallVector<Value, 4> operands;
1690  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1691  op->getOperands(), operands))) {
1692  return failure();
1693  }
1694  return matchAndRewrite(op, operands, dialectRewriter);
1695 }
1696 
1697 //===----------------------------------------------------------------------===//
1698 // OperationLegalizer
1699 //===----------------------------------------------------------------------===//
1700 
1701 namespace {
1702 /// A set of rewrite patterns that can be used to legalize a given operation.
1703 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1704 
1705 /// This class defines a recursive operation legalizer.
1706 class OperationLegalizer {
1707 public:
1708  using LegalizationAction = ConversionTarget::LegalizationAction;
1709 
1710  OperationLegalizer(const ConversionTarget &targetInfo,
1711  const FrozenRewritePatternSet &patterns,
1712  const ConversionConfig &config);
1713 
1714  /// Returns true if the given operation is known to be illegal on the target.
1715  bool isIllegal(Operation *op) const;
1716 
1717  /// Attempt to legalize the given operation. Returns success if the operation
1718  /// was legalized, failure otherwise.
1719  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1720 
1721  /// Returns the conversion target in use by the legalizer.
1722  const ConversionTarget &getTarget() { return target; }
1723 
1724 private:
1725  /// Attempt to legalize the given operation by folding it.
1726  LogicalResult legalizeWithFold(Operation *op,
1727  ConversionPatternRewriter &rewriter);
1728 
1729  /// Attempt to legalize the given operation by applying a pattern. Returns
1730  /// success if the operation was legalized, failure otherwise.
1731  LogicalResult legalizeWithPattern(Operation *op,
1732  ConversionPatternRewriter &rewriter);
1733 
1734  /// Return true if the given pattern may be applied to the given operation,
1735  /// false otherwise.
1736  bool canApplyPattern(Operation *op, const Pattern &pattern,
1737  ConversionPatternRewriter &rewriter);
1738 
1739  /// Legalize the resultant IR after successfully applying the given pattern.
1740  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1741  ConversionPatternRewriter &rewriter,
1742  RewriterState &curState);
1743 
1744  /// Legalizes the actions registered during the execution of a pattern.
1745  LogicalResult
1746  legalizePatternBlockRewrites(Operation *op,
1747  ConversionPatternRewriter &rewriter,
1749  RewriterState &state, RewriterState &newState);
1750  LogicalResult legalizePatternCreatedOperations(
1752  RewriterState &state, RewriterState &newState);
1753  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1755  RewriterState &state,
1756  RewriterState &newState);
1757 
1758  //===--------------------------------------------------------------------===//
1759  // Cost Model
1760  //===--------------------------------------------------------------------===//
1761 
1762  /// Build an optimistic legalization graph given the provided patterns. This
1763  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1764  /// patterns for operations that are not directly legal, but may be
1765  /// transitively legal for the current target given the provided patterns.
1766  void buildLegalizationGraph(
1767  LegalizationPatterns &anyOpLegalizerPatterns,
1769 
1770  /// Compute the benefit of each node within the computed legalization graph.
1771  /// This orders the patterns within 'legalizerPatterns' based upon two
1772  /// criteria:
1773  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1774  /// represent the more direct mapping to the target.
1775  /// 2) When comparing patterns with the same legalization depth, prefer the
1776  /// pattern with the highest PatternBenefit. This allows for users to
1777  /// prefer specific legalizations over others.
1778  void computeLegalizationGraphBenefit(
1779  LegalizationPatterns &anyOpLegalizerPatterns,
1781 
1782  /// Compute the legalization depth when legalizing an operation of the given
1783  /// type.
1784  unsigned computeOpLegalizationDepth(
1785  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1787 
1788  /// Apply the conversion cost model to the given set of patterns, and return
1789  /// the smallest legalization depth of any of the patterns. See
1790  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1791  unsigned applyCostModelToPatterns(
1792  LegalizationPatterns &patterns,
1793  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1795 
1796  /// The current set of patterns that have been applied.
1797  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1798 
1799  /// The legalization information provided by the target.
1800  const ConversionTarget &target;
1801 
1802  /// The pattern applicator to use for conversions.
1803  PatternApplicator applicator;
1804 
1805  /// Dialect conversion configuration.
1806  const ConversionConfig &config;
1807 };
1808 } // namespace
1809 
1810 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1811  const FrozenRewritePatternSet &patterns,
1812  const ConversionConfig &config)
1813  : target(targetInfo), applicator(patterns), config(config) {
1814  // The set of patterns that can be applied to illegal operations to transform
1815  // them into legal ones.
1817  LegalizationPatterns anyOpLegalizerPatterns;
1818 
1819  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1820  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1821 }
1822 
1823 bool OperationLegalizer::isIllegal(Operation *op) const {
1824  return target.isIllegal(op);
1825 }
1826 
1827 LogicalResult
1828 OperationLegalizer::legalize(Operation *op,
1829  ConversionPatternRewriter &rewriter) {
1830 #ifndef NDEBUG
1831  const char *logLineComment =
1832  "//===-------------------------------------------===//\n";
1833 
1834  auto &logger = rewriter.getImpl().logger;
1835 #endif
1836  LLVM_DEBUG({
1837  logger.getOStream() << "\n";
1838  logger.startLine() << logLineComment;
1839  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1840  << op << ") {\n";
1841  logger.indent();
1842 
1843  // If the operation has no regions, just print it here.
1844  if (op->getNumRegions() == 0) {
1845  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1846  logger.getOStream() << "\n\n";
1847  }
1848  });
1849 
1850  // Check if this operation is legal on the target.
1851  if (auto legalityInfo = target.isLegal(op)) {
1852  LLVM_DEBUG({
1853  logSuccess(
1854  logger, "operation marked legal by the target{0}",
1855  legalityInfo->isRecursivelyLegal
1856  ? "; NOTE: operation is recursively legal; skipping internals"
1857  : "");
1858  logger.startLine() << logLineComment;
1859  });
1860 
1861  // If this operation is recursively legal, mark its children as ignored so
1862  // that we don't consider them for legalization.
1863  if (legalityInfo->isRecursivelyLegal) {
1864  op->walk([&](Operation *nested) {
1865  if (op != nested)
1866  rewriter.getImpl().ignoredOps.insert(nested);
1867  });
1868  }
1869 
1870  return success();
1871  }
1872 
1873  // Check to see if the operation is ignored and doesn't need to be converted.
1874  if (rewriter.getImpl().isOpIgnored(op)) {
1875  LLVM_DEBUG({
1876  logSuccess(logger, "operation marked 'ignored' during conversion");
1877  logger.startLine() << logLineComment;
1878  });
1879  return success();
1880  }
1881 
1882  // If the operation isn't legal, try to fold it in-place.
1883  // TODO: Should we always try to do this, even if the op is
1884  // already legal?
1885  if (succeeded(legalizeWithFold(op, rewriter))) {
1886  LLVM_DEBUG({
1887  logSuccess(logger, "operation was folded");
1888  logger.startLine() << logLineComment;
1889  });
1890  return success();
1891  }
1892 
1893  // Otherwise, we need to apply a legalization pattern to this operation.
1894  if (succeeded(legalizeWithPattern(op, rewriter))) {
1895  LLVM_DEBUG({
1896  logSuccess(logger, "");
1897  logger.startLine() << logLineComment;
1898  });
1899  return success();
1900  }
1901 
1902  LLVM_DEBUG({
1903  logFailure(logger, "no matched legalization pattern");
1904  logger.startLine() << logLineComment;
1905  });
1906  return failure();
1907 }
1908 
1909 LogicalResult
1910 OperationLegalizer::legalizeWithFold(Operation *op,
1911  ConversionPatternRewriter &rewriter) {
1912  auto &rewriterImpl = rewriter.getImpl();
1913  RewriterState curState = rewriterImpl.getCurrentState();
1914 
1915  LLVM_DEBUG({
1916  rewriterImpl.logger.startLine() << "* Fold {\n";
1917  rewriterImpl.logger.indent();
1918  });
1919 
1920  // Try to fold the operation.
1921  SmallVector<Value, 2> replacementValues;
1922  rewriter.setInsertionPoint(op);
1923  if (failed(rewriter.tryFold(op, replacementValues))) {
1924  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1925  return failure();
1926  }
1927  // An empty list of replacement values indicates that the fold was in-place.
1928  // As the operation changed, a new legalization needs to be attempted.
1929  if (replacementValues.empty())
1930  return legalize(op, rewriter);
1931 
1932  // Insert a replacement for 'op' with the folded replacement values.
1933  rewriter.replaceOp(op, replacementValues);
1934 
1935  // Recursively legalize any new constant operations.
1936  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
1937  i != e; ++i) {
1938  auto *createOp =
1939  dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
1940  if (!createOp)
1941  continue;
1942  if (failed(legalize(createOp->getOperation(), rewriter))) {
1943  LLVM_DEBUG(logFailure(rewriterImpl.logger,
1944  "failed to legalize generated constant '{0}'",
1945  createOp->getOperation()->getName()));
1946  rewriterImpl.resetState(curState);
1947  return failure();
1948  }
1949  }
1950 
1951  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1952  return success();
1953 }
1954 
1955 LogicalResult
1956 OperationLegalizer::legalizeWithPattern(Operation *op,
1957  ConversionPatternRewriter &rewriter) {
1958  auto &rewriterImpl = rewriter.getImpl();
1959 
1960  // Functor that returns if the given pattern may be applied.
1961  auto canApply = [&](const Pattern &pattern) {
1962  bool canApply = canApplyPattern(op, pattern, rewriter);
1963  if (canApply && config.listener)
1964  config.listener->notifyPatternBegin(pattern, op);
1965  return canApply;
1966  };
1967 
1968  // Functor that cleans up the rewriter state after a pattern failed to match.
1969  RewriterState curState = rewriterImpl.getCurrentState();
1970  auto onFailure = [&](const Pattern &pattern) {
1971  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
1972  LLVM_DEBUG({
1973  logFailure(rewriterImpl.logger, "pattern failed to match");
1974  if (rewriterImpl.config.notifyCallback) {
1976  diag << "Failed to apply pattern \"" << pattern.getDebugName()
1977  << "\" on op:\n"
1978  << *op;
1979  rewriterImpl.config.notifyCallback(diag);
1980  }
1981  });
1982  if (config.listener)
1983  config.listener->notifyPatternEnd(pattern, failure());
1984  rewriterImpl.resetState(curState);
1985  appliedPatterns.erase(&pattern);
1986  };
1987 
1988  // Functor that performs additional legalization when a pattern is
1989  // successfully applied.
1990  auto onSuccess = [&](const Pattern &pattern) {
1991  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
1992  auto result = legalizePatternResult(op, pattern, rewriter, curState);
1993  appliedPatterns.erase(&pattern);
1994  if (failed(result))
1995  rewriterImpl.resetState(curState);
1996  if (config.listener)
1997  config.listener->notifyPatternEnd(pattern, result);
1998  return result;
1999  };
2000 
2001  // Try to match and rewrite a pattern on this operation.
2002  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2003  onSuccess);
2004 }
2005 
2006 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2007  ConversionPatternRewriter &rewriter) {
2008  LLVM_DEBUG({
2009  auto &os = rewriter.getImpl().logger;
2010  os.getOStream() << "\n";
2011  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2012  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2013  os.getOStream() << ")' {\n";
2014  os.indent();
2015  });
2016 
2017  // Ensure that we don't cycle by not allowing the same pattern to be
2018  // applied twice in the same recursion stack if it is not known to be safe.
2019  if (!pattern.hasBoundedRewriteRecursion() &&
2020  !appliedPatterns.insert(&pattern).second) {
2021  LLVM_DEBUG(
2022  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2023  return false;
2024  }
2025  return true;
2026 }
2027 
2028 LogicalResult
2029 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2030  ConversionPatternRewriter &rewriter,
2031  RewriterState &curState) {
2032  auto &impl = rewriter.getImpl();
2033 
2034 #ifndef NDEBUG
2035  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2036  // Check that the root was either replaced or updated in place.
2037  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2038  auto replacedRoot = [&] {
2039  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2040  };
2041  auto updatedRootInPlace = [&] {
2042  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2043  };
2044  assert((replacedRoot() || updatedRootInPlace()) &&
2045  "expected pattern to replace the root operation");
2046 #endif // NDEBUG
2047 
2048  // Legalize each of the actions registered during application.
2049  RewriterState newState = impl.getCurrentState();
2050  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2051  newState)) ||
2052  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2053  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2054  newState))) {
2055  return failure();
2056  }
2057 
2058  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2059  return success();
2060 }
2061 
2062 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2063  Operation *op, ConversionPatternRewriter &rewriter,
2064  ConversionPatternRewriterImpl &impl, RewriterState &state,
2065  RewriterState &newState) {
2066  SmallPtrSet<Operation *, 16> operationsToIgnore;
2067 
2068  // If the pattern moved or created any blocks, make sure the types of block
2069  // arguments get legalized.
2070  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2071  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2072  if (!rewrite)
2073  continue;
2074  Block *block = rewrite->getBlock();
2075  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2076  ReplaceBlockArgRewrite>(rewrite))
2077  continue;
2078  // Only check blocks outside of the current operation.
2079  Operation *parentOp = block->getParentOp();
2080  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2081  continue;
2082 
2083  // If the region of the block has a type converter, try to convert the block
2084  // directly.
2085  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2086  std::optional<TypeConverter::SignatureConversion> conversion =
2087  converter->convertBlockSignature(block);
2088  if (!conversion) {
2089  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2090  "block"));
2091  return failure();
2092  }
2093  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2094  continue;
2095  }
2096 
2097  // Otherwise, check that this operation isn't one generated by this pattern.
2098  // This is because we will attempt to legalize the parent operation, and
2099  // blocks in regions created by this pattern will already be legalized later
2100  // on. If we haven't built the set yet, build it now.
2101  if (operationsToIgnore.empty()) {
2102  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2103  ++i) {
2104  auto *createOp =
2105  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2106  if (!createOp)
2107  continue;
2108  operationsToIgnore.insert(createOp->getOperation());
2109  }
2110  }
2111 
2112  // If this operation should be considered for re-legalization, try it.
2113  if (operationsToIgnore.insert(parentOp).second &&
2114  failed(legalize(parentOp, rewriter))) {
2115  LLVM_DEBUG(logFailure(impl.logger,
2116  "operation '{0}'({1}) became illegal after rewrite",
2117  parentOp->getName(), parentOp));
2118  return failure();
2119  }
2120  }
2121  return success();
2122 }
2123 
2124 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2126  RewriterState &state, RewriterState &newState) {
2127  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2128  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2129  if (!createOp)
2130  continue;
2131  Operation *op = createOp->getOperation();
2132  if (failed(legalize(op, rewriter))) {
2133  LLVM_DEBUG(logFailure(impl.logger,
2134  "failed to legalize generated operation '{0}'({1})",
2135  op->getName(), op));
2136  return failure();
2137  }
2138  }
2139  return success();
2140 }
2141 
2142 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2144  RewriterState &state, RewriterState &newState) {
2145  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2146  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2147  if (!rewrite)
2148  continue;
2149  Operation *op = rewrite->getOperation();
2150  if (failed(legalize(op, rewriter))) {
2151  LLVM_DEBUG(logFailure(
2152  impl.logger, "failed to legalize operation updated in-place '{0}'",
2153  op->getName()));
2154  return failure();
2155  }
2156  }
2157  return success();
2158 }
2159 
2160 //===----------------------------------------------------------------------===//
2161 // Cost Model
2162 
2163 void OperationLegalizer::buildLegalizationGraph(
2164  LegalizationPatterns &anyOpLegalizerPatterns,
2165  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2166  // A mapping between an operation and a set of operations that can be used to
2167  // generate it.
2169  // A mapping between an operation and any currently invalid patterns it has.
2171  // A worklist of patterns to consider for legality.
2172  SetVector<const Pattern *> patternWorklist;
2173 
2174  // Build the mapping from operations to the parent ops that may generate them.
2175  applicator.walkAllPatterns([&](const Pattern &pattern) {
2176  std::optional<OperationName> root = pattern.getRootKind();
2177 
2178  // If the pattern has no specific root, we can't analyze the relationship
2179  // between the root op and generated operations. Given that, add all such
2180  // patterns to the legalization set.
2181  if (!root) {
2182  anyOpLegalizerPatterns.push_back(&pattern);
2183  return;
2184  }
2185 
2186  // Skip operations that are always known to be legal.
2187  if (target.getOpAction(*root) == LegalizationAction::Legal)
2188  return;
2189 
2190  // Add this pattern to the invalid set for the root op and record this root
2191  // as a parent for any generated operations.
2192  invalidPatterns[*root].insert(&pattern);
2193  for (auto op : pattern.getGeneratedOps())
2194  parentOps[op].insert(*root);
2195 
2196  // Add this pattern to the worklist.
2197  patternWorklist.insert(&pattern);
2198  });
2199 
2200  // If there are any patterns that don't have a specific root kind, we can't
2201  // make direct assumptions about what operations will never be legalized.
2202  // Note: Technically we could, but it would require an analysis that may
2203  // recurse into itself. It would be better to perform this kind of filtering
2204  // at a higher level than here anyways.
2205  if (!anyOpLegalizerPatterns.empty()) {
2206  for (const Pattern *pattern : patternWorklist)
2207  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2208  return;
2209  }
2210 
2211  while (!patternWorklist.empty()) {
2212  auto *pattern = patternWorklist.pop_back_val();
2213 
2214  // Check to see if any of the generated operations are invalid.
2215  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2216  std::optional<LegalizationAction> action = target.getOpAction(op);
2217  return !legalizerPatterns.count(op) &&
2218  (!action || action == LegalizationAction::Illegal);
2219  }))
2220  continue;
2221 
2222  // Otherwise, if all of the generated operation are valid, this op is now
2223  // legal so add all of the child patterns to the worklist.
2224  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2225  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2226 
2227  // Add any invalid patterns of the parent operations to see if they have now
2228  // become legal.
2229  for (auto op : parentOps[*pattern->getRootKind()])
2230  patternWorklist.set_union(invalidPatterns[op]);
2231  }
2232 }
2233 
2234 void OperationLegalizer::computeLegalizationGraphBenefit(
2235  LegalizationPatterns &anyOpLegalizerPatterns,
2236  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2237  // The smallest pattern depth, when legalizing an operation.
2238  DenseMap<OperationName, unsigned> minOpPatternDepth;
2239 
2240  // For each operation that is transitively legal, compute a cost for it.
2241  for (auto &opIt : legalizerPatterns)
2242  if (!minOpPatternDepth.count(opIt.first))
2243  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2244  legalizerPatterns);
2245 
2246  // Apply the cost model to the patterns that can match any operation. Those
2247  // with a specific operation type are already resolved when computing the op
2248  // legalization depth.
2249  if (!anyOpLegalizerPatterns.empty())
2250  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2251  legalizerPatterns);
2252 
2253  // Apply a cost model to the pattern applicator. We order patterns first by
2254  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2255  // decreasing benefit.
2256  applicator.applyCostModel([&](const Pattern &pattern) {
2257  ArrayRef<const Pattern *> orderedPatternList;
2258  if (std::optional<OperationName> rootName = pattern.getRootKind())
2259  orderedPatternList = legalizerPatterns[*rootName];
2260  else
2261  orderedPatternList = anyOpLegalizerPatterns;
2262 
2263  // If the pattern is not found, then it was removed and cannot be matched.
2264  auto *it = llvm::find(orderedPatternList, &pattern);
2265  if (it == orderedPatternList.end())
2267 
2268  // Patterns found earlier in the list have higher benefit.
2269  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2270  });
2271 }
2272 
2273 unsigned OperationLegalizer::computeOpLegalizationDepth(
2274  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2275  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2276  // Check for existing depth.
2277  auto depthIt = minOpPatternDepth.find(op);
2278  if (depthIt != minOpPatternDepth.end())
2279  return depthIt->second;
2280 
2281  // If a mapping for this operation does not exist, then this operation
2282  // is always legal. Return 0 as the depth for a directly legal operation.
2283  auto opPatternsIt = legalizerPatterns.find(op);
2284  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2285  return 0u;
2286 
2287  // Record this initial depth in case we encounter this op again when
2288  // recursively computing the depth.
2289  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2290 
2291  // Apply the cost model to the operation patterns, and update the minimum
2292  // depth.
2293  unsigned minDepth = applyCostModelToPatterns(
2294  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2295  minOpPatternDepth[op] = minDepth;
2296  return minDepth;
2297 }
2298 
2299 unsigned OperationLegalizer::applyCostModelToPatterns(
2300  LegalizationPatterns &patterns,
2301  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2302  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2303  unsigned minDepth = std::numeric_limits<unsigned>::max();
2304 
2305  // Compute the depth for each pattern within the set.
2307  patternsByDepth.reserve(patterns.size());
2308  for (const Pattern *pattern : patterns) {
2309  unsigned depth = 1;
2310  for (auto generatedOp : pattern->getGeneratedOps()) {
2311  unsigned generatedOpDepth = computeOpLegalizationDepth(
2312  generatedOp, minOpPatternDepth, legalizerPatterns);
2313  depth = std::max(depth, generatedOpDepth + 1);
2314  }
2315  patternsByDepth.emplace_back(pattern, depth);
2316 
2317  // Update the minimum depth of the pattern list.
2318  minDepth = std::min(minDepth, depth);
2319  }
2320 
2321  // If the operation only has one legalization pattern, there is no need to
2322  // sort them.
2323  if (patternsByDepth.size() == 1)
2324  return minDepth;
2325 
2326  // Sort the patterns by those likely to be the most beneficial.
2327  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2328  [](const std::pair<const Pattern *, unsigned> &lhs,
2329  const std::pair<const Pattern *, unsigned> &rhs) {
2330  // First sort by the smaller pattern legalization
2331  // depth.
2332  if (lhs.second != rhs.second)
2333  return lhs.second < rhs.second;
2334 
2335  // Then sort by the larger pattern benefit.
2336  auto lhsBenefit = lhs.first->getBenefit();
2337  auto rhsBenefit = rhs.first->getBenefit();
2338  return lhsBenefit > rhsBenefit;
2339  });
2340 
2341  // Update the legalization pattern to use the new sorted list.
2342  patterns.clear();
2343  for (auto &patternIt : patternsByDepth)
2344  patterns.push_back(patternIt.first);
2345  return minDepth;
2346 }
2347 
2348 //===----------------------------------------------------------------------===//
2349 // OperationConverter
2350 //===----------------------------------------------------------------------===//
namespace {
/// Selects how the operation converter reacts to operations that cannot be
/// legalized and whether rewrites are kept.
enum OpConversionMode {
  /// Failed conversions are tolerated, so illegal operations may remain in
  /// the IR next to converted ones.
  Partial,

  /// Every operation has to be legal for the given target, otherwise the
  /// conversion fails.
  Full,

  /// Operations are only analyzed for legality; no rewrites are applied to
  /// them on success.
  Analysis,
};
} // namespace
2366 
2367 namespace mlir {
2368 // This class converts operations to a given conversion target via a set of
2369 // rewrite patterns. The conversion behaves differently depending on the
2370 // conversion mode.
2372  explicit OperationConverter(const ConversionTarget &target,
2373  const FrozenRewritePatternSet &patterns,
2374  const ConversionConfig &config,
2375  OpConversionMode mode)
2376  : config(config), opLegalizer(target, patterns, this->config),
2377  mode(mode) {}
2378 
2379  /// Converts the given operations to the conversion target.
2380  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2381 
2382 private:
2383  /// Converts an operation with the given rewriter.
2384  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2385 
2386  /// This method is called after the conversion process to legalize any
2387  /// remaining artifacts and complete the conversion.
2388  LogicalResult finalize(ConversionPatternRewriter &rewriter);
2389 
2390  /// Legalize the types of converted block arguments.
2391  LogicalResult
2392  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2393  ConversionPatternRewriterImpl &rewriterImpl);
2394 
2395  /// Legalize any unresolved type materializations.
2396  LogicalResult legalizeUnresolvedMaterializations(
2397  ConversionPatternRewriter &rewriter,
2398  ConversionPatternRewriterImpl &rewriterImpl,
2399  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
2400 
2401  /// Legalize an operation result that was marked as "erased".
2402  LogicalResult
2403  legalizeErasedResult(Operation *op, OpResult result,
2404  ConversionPatternRewriterImpl &rewriterImpl);
2405 
2406  /// Legalize an operation result that was replaced with a value of a different
2407  /// type.
2408  LogicalResult legalizeChangedResultType(
2409  Operation *op, OpResult result, Value newValue,
2410  const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2411  ConversionPatternRewriterImpl &rewriterImpl,
2412  const DenseMap<Value, SmallVector<Value>> &inverseMapping);
2413 
2414  /// Dialect conversion configuration.
2415  ConversionConfig config;
2416 
2417  /// The legalizer to use when converting operations.
2418  OperationLegalizer opLegalizer;
2419 
2420  /// The conversion mode to use when legalizing operations.
2421  OpConversionMode mode;
2422 };
2423 } // namespace mlir
2424 
2425 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2426  Operation *op) {
2427  // Legalize the given operation.
2428  if (failed(opLegalizer.legalize(op, rewriter))) {
2429  // Handle the case of a failed conversion for each of the different modes.
2430  // Full conversions expect all operations to be converted.
2431  if (mode == OpConversionMode::Full)
2432  return op->emitError()
2433  << "failed to legalize operation '" << op->getName() << "'";
2434  // Partial conversions allow conversions to fail iff the operation was not
2435  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2436  // set, non-legalizable ops are added to that set.
2437  if (mode == OpConversionMode::Partial) {
2438  if (opLegalizer.isIllegal(op))
2439  return op->emitError()
2440  << "failed to legalize operation '" << op->getName()
2441  << "' that was explicitly marked illegal";
2442  if (config.unlegalizedOps)
2443  config.unlegalizedOps->insert(op);
2444  }
2445  } else if (mode == OpConversionMode::Analysis) {
2446  // Analysis conversions don't fail if any operations fail to legalize,
2447  // they are only interested in the operations that were successfully
2448  // legalized.
2449  if (config.legalizableOps)
2450  config.legalizableOps->insert(op);
2451  }
2452  return success();
2453 }
2454 
2456  if (ops.empty())
2457  return success();
2458  const ConversionTarget &target = opLegalizer.getTarget();
2459 
2460  // Compute the set of operations and blocks to convert.
2461  SmallVector<Operation *> toConvert;
2462  for (auto *op : ops) {
2464  [&](Operation *op) {
2465  toConvert.push_back(op);
2466  // Don't check this operation's children for conversion if the
2467  // operation is recursively legal.
2468  auto legalityInfo = target.isLegal(op);
2469  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2470  return WalkResult::skip();
2471  return WalkResult::advance();
2472  });
2473  }
2474 
2475  // Convert each operation and discard rewrites on failure.
2476  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2477  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2478 
2479  for (auto *op : toConvert)
2480  if (failed(convert(rewriter, op)))
2481  return rewriterImpl.undoRewrites(), failure();
2482 
2483  // Now that all of the operations have been converted, finalize the conversion
2484  // process to ensure any lingering conversion artifacts are cleaned up and
2485  // legalized.
2486  if (failed(finalize(rewriter)))
2487  return rewriterImpl.undoRewrites(), failure();
2488 
2489  // After a successful conversion, apply rewrites if this is not an analysis
2490  // conversion.
2491  if (mode == OpConversionMode::Analysis) {
2492  rewriterImpl.undoRewrites();
2493  } else {
2494  rewriterImpl.applyRewrites();
2495  }
2496  return success();
2497 }
2498 
2499 LogicalResult
2500 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2501  std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
2502  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2503  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2504  inverseMapping)) ||
2505  failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2506  return failure();
2507 
2508  // Process requested operation replacements.
2509  for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
2510  auto *opReplacement =
2511  dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
2512  if (!opReplacement || !opReplacement->hasChangedResults())
2513  continue;
2514  Operation *op = opReplacement->getOperation();
2515  for (OpResult result : op->getResults()) {
2516  Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2517 
2518  // If the operation result was replaced with null, all of the uses of this
2519  // value should be replaced.
2520  if (!newValue) {
2521  if (failed(legalizeErasedResult(op, result, rewriterImpl)))
2522  return failure();
2523  continue;
2524  }
2525 
2526  // Otherwise, check to see if the type of the result changed.
2527  if (result.getType() == newValue.getType())
2528  continue;
2529 
2530  // Compute the inverse mapping only if it is really needed.
2531  if (!inverseMapping)
2532  inverseMapping = rewriterImpl.mapping.getInverse();
2533 
2534  // Legalize this result.
2535  rewriter.setInsertionPoint(op);
2536  if (failed(legalizeChangedResultType(
2537  op, result, newValue, opReplacement->getConverter(), rewriter,
2538  rewriterImpl, *inverseMapping)))
2539  return failure();
2540  }
2541  }
2542  return success();
2543 }
2544 
2545 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2546  ConversionPatternRewriter &rewriter,
2547  ConversionPatternRewriterImpl &rewriterImpl) {
2548  // Functor used to check if all users of a value will be dead after
2549  // conversion.
2550  auto findLiveUser = [&](Value val) {
2551  auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2552  return rewriterImpl.isOpIgnored(user);
2553  });
2554  return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2555  };
2556  // Note: `rewrites` may be reallocated as the loop is running.
2557  for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
2558  ++i) {
2559  auto &rewrite = rewriterImpl.rewrites[i];
2560  if (auto *blockTypeConversionRewrite =
2561  dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
2562  if (failed(blockTypeConversionRewrite->materializeLiveConversions(
2563  findLiveUser)))
2564  return failure();
2565  }
2566  return success();
2567 }
2568 
2569 /// Replace the results of a materialization operation with the given values.
2570 static void
2572  ResultRange matResults, ValueRange values,
2573  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2574  matResults.replaceAllUsesWith(values);
2575 
2576  // For each of the materialization results, update the inverse mappings to
2577  // point to the replacement values.
2578  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
2579  auto inverseMapIt = inverseMapping.find(matResult);
2580  if (inverseMapIt == inverseMapping.end())
2581  continue;
2582 
2583  // Update the reverse mapping, or remove the mapping if we couldn't update
2584  // it. Not being able to update signals that the mapping would have become
2585  // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
2586  // propagated through temporary materializations. We simply drop the
2587  // mapping, and let the post-conversion replacement logic handle updating
2588  // uses.
2589  for (Value inverseMapVal : inverseMapIt->second)
2590  if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
2591  rewriterImpl.mapping.erase(inverseMapVal);
2592  }
2593 }
2594 
2595 /// Compute all of the unresolved materializations that will persist beyond the
2596 /// conversion process, and require inserting a proper user materialization for.
2599  &materializationOps,
2600  ConversionPatternRewriter &rewriter,
2601  ConversionPatternRewriterImpl &rewriterImpl,
2602  DenseMap<Value, SmallVector<Value>> &inverseMapping,
2603  SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
2604  // Helper function to check if the given value or a not yet materialized
2605  // replacement of the given value is live.
2606  // Note: `inverseMapping` maps from replaced values to original values.
2607  auto isLive = [&](Value value) {
2608  auto findFn = [&](Operation *user) {
2609  auto matIt = materializationOps.find(user);
2610  if (matIt != materializationOps.end())
2611  return !necessaryMaterializations.count(matIt->second);
2612  return rewriterImpl.isOpIgnored(user);
2613  };
2614  // A worklist is needed because a value may have gone through a chain of
2615  // replacements and each of the replaced values may have live users.
2616  SmallVector<Value> worklist;
2617  worklist.push_back(value);
2618  while (!worklist.empty()) {
2619  Value next = worklist.pop_back_val();
2620  if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
2621  return true;
2622  // This value may be replacing another value that has a live user.
2623  llvm::append_range(worklist, inverseMapping.lookup(next));
2624  }
2625  return false;
2626  };
2627 
2628  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
2629  [&](Value invalidRoot, Value value, Type type) {
2630  // Check to see if the input operation was remapped to a variant of the
2631  // output.
2632  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2633  if (remappedValue.getType() == type && remappedValue != invalidRoot)
2634  return remappedValue;
2635 
2636  // Check to see if the input is a materialization operation that
2637  // provides an inverse conversion. We just check blindly for
2638  // UnrealizedConversionCastOp here, but it has no effect on correctness.
2639  auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
2640  if (inputCastOp && inputCastOp->getNumOperands() == 1)
2641  return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2642  type);
2643 
2644  return Value();
2645  };
2646 
2648  for (auto &rewrite : rewriterImpl.rewrites) {
2649  auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
2650  if (!mat)
2651  continue;
2652  materializationOps.try_emplace(mat->getOperation(), mat);
2653  worklist.insert(mat);
2654  }
2655  while (!worklist.empty()) {
2656  UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
2657  UnrealizedConversionCastOp op = mat->getOperation();
2658 
2659  // We currently only handle target materializations here.
2660  assert(op->getNumResults() == 1 && "unexpected materialization type");
2661  OpResult opResult = op->getOpResult(0);
2662  Type outputType = opResult.getType();
2663  Operation::operand_range inputOperands = op.getOperands();
2664 
2665  // Try to forward propagate operands for user conversion casts that result
2666  // in the input types of the current cast.
2667  for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
2668  auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2669  if (!castOp)
2670  continue;
2671  if (castOp->getResultTypes() == inputOperands.getTypes()) {
2672  replaceMaterialization(rewriterImpl, opResult, inputOperands,
2673  inverseMapping);
2674  necessaryMaterializations.remove(materializationOps.lookup(user));
2675  }
2676  }
2677 
2678  // Try to avoid materializing a resolved materialization if possible.
2679  // Handle the case of a 1-1 materialization.
2680  if (inputOperands.size() == 1) {
2681  // Check to see if the input operation was remapped to a variant of the
2682  // output.
2683  Value remappedValue =
2684  lookupRemappedValue(opResult, inputOperands[0], outputType);
2685  if (remappedValue && remappedValue != opResult) {
2686  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2687  inverseMapping);
2688  necessaryMaterializations.remove(mat);
2689  continue;
2690  }
2691  } else {
2692  // TODO: Avoid materializing other types of conversions here.
2693  }
2694 
2695  // If the materialization does not have any live users, we don't need to
2696  // generate a user materialization for it.
2697  bool isMaterializationLive = isLive(opResult);
2698  if (!isMaterializationLive)
2699  continue;
2700  if (!necessaryMaterializations.insert(mat))
2701  continue;
2702 
2703  // Reprocess input materializations to see if they have an updated status.
2704  for (Value input : inputOperands) {
2705  if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2706  if (auto *mat = materializationOps.lookup(parentOp))
2707  worklist.insert(mat);
2708  }
2709  }
2710  }
2711 }
2712 
2713 /// Legalize the given unresolved materialization. Returns success if the
2714 /// materialization was legalized, failure otherwise.
2716  UnresolvedMaterializationRewrite &mat,
2718  &materializationOps,
2719  ConversionPatternRewriter &rewriter,
2720  ConversionPatternRewriterImpl &rewriterImpl,
2721  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2722  auto findLiveUser = [&](auto &&users) {
2723  auto liveUserIt = llvm::find_if_not(
2724  users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
2725  return liveUserIt == users.end() ? nullptr : *liveUserIt;
2726  };
2727 
2728  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
2729  [&](Value value, Type type) {
2730  // Check to see if the input operation was remapped to a variant of the
2731  // output.
2732  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2733  if (remappedValue.getType() == type)
2734  return remappedValue;
2735  return Value();
2736  };
2737 
2738  UnrealizedConversionCastOp op = mat.getOperation();
2739  if (!rewriterImpl.ignoredOps.insert(op))
2740  return success();
2741 
2742  // We currently only handle target materializations here.
2743  OpResult opResult = op->getOpResult(0);
2744  Operation::operand_range inputOperands = op.getOperands();
2745  Type outputType = opResult.getType();
2746 
2747  // If any input to this materialization is another materialization, resolve
2748  // the input first.
2749  for (Value value : op->getOperands()) {
2750  auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
2751  if (!valueCast)
2752  continue;
2753 
2754  auto matIt = materializationOps.find(valueCast);
2755  if (matIt != materializationOps.end())
2757  *matIt->second, materializationOps, rewriter, rewriterImpl,
2758  inverseMapping)))
2759  return failure();
2760  }
2761 
2762  // Perform a last ditch attempt to avoid materializing a resolved
2763  // materialization if possible.
2764  // Handle the case of a 1-1 materialization.
2765  if (inputOperands.size() == 1) {
2766  // Check to see if the input operation was remapped to a variant of the
2767  // output.
2768  Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2769  if (remappedValue && remappedValue != opResult) {
2770  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2771  inverseMapping);
2772  return success();
2773  }
2774  } else {
2775  // TODO: Avoid materializing other types of conversions here.
2776  }
2777 
2778  // Try to materialize the conversion.
2779  if (const TypeConverter *converter = mat.getConverter()) {
2780  rewriter.setInsertionPoint(op);
2781  Value newMaterialization;
2782  switch (mat.getMaterializationKind()) {
2784  // Try to materialize an argument conversion.
2785  newMaterialization = converter->materializeArgumentConversion(
2786  rewriter, op->getLoc(), outputType, inputOperands);
2787  if (newMaterialization)
2788  break;
2789  // If an argument materialization failed, fallback to trying a target
2790  // materialization.
2791  [[fallthrough]];
2792  case MaterializationKind::Target:
2793  newMaterialization = converter->materializeTargetConversion(
2794  rewriter, op->getLoc(), outputType, inputOperands);
2795  break;
2796  case MaterializationKind::Source:
2797  newMaterialization = converter->materializeSourceConversion(
2798  rewriter, op->getLoc(), outputType, inputOperands);
2799  break;
2800  }
2801  if (newMaterialization) {
2802  assert(newMaterialization.getType() == outputType &&
2803  "materialization callback produced value of incorrect type");
2804  replaceMaterialization(rewriterImpl, opResult, newMaterialization,
2805  inverseMapping);
2806  return success();
2807  }
2808  }
2809 
2811  << "failed to legalize unresolved materialization "
2812  "from ("
2813  << inputOperands.getTypes() << ") to " << outputType
2814  << " that remained live after conversion";
2815  if (Operation *liveUser = findLiveUser(op->getUsers())) {
2816  diag.attachNote(liveUser->getLoc())
2817  << "see existing live user here: " << *liveUser;
2818  }
2819  return failure();
2820 }
2821 
2822 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2823  ConversionPatternRewriter &rewriter,
2824  ConversionPatternRewriterImpl &rewriterImpl,
2825  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
2826  inverseMapping = rewriterImpl.mapping.getInverse();
2827 
2828  // As an initial step, compute all of the inserted materializations that we
2829  // expect to persist beyond the conversion process.
2831  SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
2832  computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
2833  *inverseMapping, necessaryMaterializations);
2834 
2835  // Once computed, legalize any necessary materializations.
2836  for (auto *mat : necessaryMaterializations) {
2838  *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
2839  return failure();
2840  }
2841  return success();
2842 }
2843 
2844 LogicalResult OperationConverter::legalizeErasedResult(
2845  Operation *op, OpResult result,
2846  ConversionPatternRewriterImpl &rewriterImpl) {
2847  // If the operation result was replaced with null, all of the uses of this
2848  // value should be replaced.
2849  auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2850  return rewriterImpl.isOpIgnored(user);
2851  });
2852  if (liveUserIt != result.user_end()) {
2853  InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2854  << op->getName() << "' marked as erased";
2855  diag.attachNote(liveUserIt->getLoc())
2856  << "found live user of result #" << result.getResultNumber() << ": "
2857  << *liveUserIt;
2858  return failure();
2859  }
2860  return success();
2861 }
2862 
2863 /// Finds a user of the given value, or of any other value that the given value
2864 /// replaced, that was not replaced in the conversion process.
2866  Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2867  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2868  SmallVector<Value> worklist(1, initialValue);
2869  while (!worklist.empty()) {
2870  Value value = worklist.pop_back_val();
2871 
2872  // Walk the users of this value to see if there are any live users that
2873  // weren't replaced during conversion.
2874  auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2875  return rewriterImpl.isOpIgnored(user);
2876  });
2877  if (liveUserIt != value.user_end())
2878  return *liveUserIt;
2879  auto mapIt = inverseMapping.find(value);
2880  if (mapIt != inverseMapping.end())
2881  worklist.append(mapIt->second);
2882  }
2883  return nullptr;
2884 }
2885 
2886 LogicalResult OperationConverter::legalizeChangedResultType(
2887  Operation *op, OpResult result, Value newValue,
2888  const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2889  ConversionPatternRewriterImpl &rewriterImpl,
2890  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2891  Operation *liveUser =
2892  findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2893  if (!liveUser)
2894  return success();
2895 
2896  // Functor used to emit a conversion error for a failed materialization.
2897  auto emitConversionError = [&] {
2899  << "failed to materialize conversion for result #"
2900  << result.getResultNumber() << " of operation '"
2901  << op->getName()
2902  << "' that remained live after conversion";
2903  diag.attachNote(liveUser->getLoc())
2904  << "see existing live user here: " << *liveUser;
2905  return failure();
2906  };
2907 
2908  // If the replacement has a type converter, attempt to materialize a
2909  // conversion back to the original type.
2910  if (!replConverter)
2911  return emitConversionError();
2912 
2913  // Materialize a conversion for this live result value.
2914  Type resultType = result.getType();
2915  Value convertedValue = replConverter->materializeSourceConversion(
2916  rewriter, op->getLoc(), resultType, newValue);
2917  if (!convertedValue)
2918  return emitConversionError();
2919 
2920  rewriterImpl.mapping.map(result, convertedValue);
2921  return success();
2922 }
2923 
2924 //===----------------------------------------------------------------------===//
2925 // Type Conversion
2926 //===----------------------------------------------------------------------===//
2927 
2929  ArrayRef<Type> types) {
2930  assert(!types.empty() && "expected valid types");
2931  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2932  addInputs(types);
2933 }
2934 
2936  assert(!types.empty() &&
2937  "1->0 type remappings don't need to be added explicitly");
2938  argTypes.append(types.begin(), types.end());
2939 }
2940 
2941 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2942  unsigned newInputNo,
2943  unsigned newInputCount) {
2944  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2945  assert(newInputCount != 0 && "expected valid input count");
2946  remappedInputs[origInputNo] =
2947  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2948 }
2949 
2951  Value replacementValue) {
2952  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2953  remappedInputs[origInputNo] =
2954  InputMapping{origInputNo, /*size=*/0, replacementValue};
2955 }
2956 
2958  SmallVectorImpl<Type> &results) const {
2959  {
2960  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2961  std::defer_lock);
2963  cacheReadLock.lock();
2964  auto existingIt = cachedDirectConversions.find(t);
2965  if (existingIt != cachedDirectConversions.end()) {
2966  if (existingIt->second)
2967  results.push_back(existingIt->second);
2968  return success(existingIt->second != nullptr);
2969  }
2970  auto multiIt = cachedMultiConversions.find(t);
2971  if (multiIt != cachedMultiConversions.end()) {
2972  results.append(multiIt->second.begin(), multiIt->second.end());
2973  return success();
2974  }
2975  }
2976  // Walk the added converters in reverse order to apply the most recently
2977  // registered first.
2978  size_t currentCount = results.size();
2979 
2980  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2981  std::defer_lock);
2982 
2983  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2984  if (std::optional<LogicalResult> result = converter(t, results)) {
2986  cacheWriteLock.lock();
2987  if (!succeeded(*result)) {
2988  cachedDirectConversions.try_emplace(t, nullptr);
2989  return failure();
2990  }
2991  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2992  if (newTypes.size() == 1)
2993  cachedDirectConversions.try_emplace(t, newTypes.front());
2994  else
2995  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2996  return success();
2997  }
2998  }
2999  return failure();
3000 }
3001 
3003  // Use the multi-type result version to convert the type.
3004  SmallVector<Type, 1> results;
3005  if (failed(convertType(t, results)))
3006  return nullptr;
3007 
3008  // Check to ensure that only one type was produced.
3009  return results.size() == 1 ? results.front() : nullptr;
3010 }
3011 
3012 LogicalResult
3014  SmallVectorImpl<Type> &results) const {
3015  for (Type type : types)
3016  if (failed(convertType(type, results)))
3017  return failure();
3018  return success();
3019 }
3020 
3021 bool TypeConverter::isLegal(Type type) const {
3022  return convertType(type) == type;
3023 }
3025  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
3026 }
3027 
3028 bool TypeConverter::isLegal(Region *region) const {
3029  return llvm::all_of(*region, [this](Block &block) {
3030  return isLegal(block.getArgumentTypes());
3031  });
3032 }
3033 
3034 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3035  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
3036 }
3037 
3038 LogicalResult
3040  SignatureConversion &result) const {
3041  // Try to convert the given input type.
3042  SmallVector<Type, 1> convertedTypes;
3043  if (failed(convertType(type, convertedTypes)))
3044  return failure();
3045 
3046  // If this argument is being dropped, there is nothing left to do.
3047  if (convertedTypes.empty())
3048  return success();
3049 
3050  // Otherwise, add the new inputs.
3051  result.addInputs(inputNo, convertedTypes);
3052  return success();
3053 }
3054 LogicalResult
3056  SignatureConversion &result,
3057  unsigned origInputOffset) const {
3058  for (unsigned i = 0, e = types.size(); i != e; ++i)
3059  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3060  return failure();
3061  return success();
3062 }
3063 
3064 Value TypeConverter::materializeConversion(
3065  ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
3066  Location loc, Type resultType, ValueRange inputs) const {
3067  for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
3068  if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
3069  return *result;
3070  return nullptr;
3071 }
3072 
3073 std::optional<TypeConverter::SignatureConversion>
3075  SignatureConversion conversion(block->getNumArguments());
3076  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3077  return std::nullopt;
3078  return conversion;
3079 }
3080 
3081 //===----------------------------------------------------------------------===//
3082 // Type attribute conversion
3083 //===----------------------------------------------------------------------===//
3086  return AttributeConversionResult(attr, resultTag);
3087 }
3088 
3091  return AttributeConversionResult(nullptr, naTag);
3092 }
3093 
3096  return AttributeConversionResult(nullptr, abortTag);
3097 }
3098 
3100  return impl.getInt() == resultTag;
3101 }
3102 
3104  return impl.getInt() == naTag;
3105 }
3106 
3108  return impl.getInt() == abortTag;
3109 }
3110 
3112  assert(hasResult() && "Cannot get result from N/A or abort");
3113  return impl.getPointer();
3114 }
3115 
3116 std::optional<Attribute>
3118  for (const TypeAttributeConversionCallbackFn &fn :
3119  llvm::reverse(typeAttributeConversions)) {
3120  AttributeConversionResult res = fn(type, attr);
3121  if (res.hasResult())
3122  return res.getResult();
3123  if (res.isAbort())
3124  return std::nullopt;
3125  }
3126  return std::nullopt;
3127 }
3128 
3129 //===----------------------------------------------------------------------===//
3130 // FunctionOpInterfaceSignatureConversion
3131 //===----------------------------------------------------------------------===//
3132 
3133 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3134  const TypeConverter &typeConverter,
3135  ConversionPatternRewriter &rewriter) {
3136  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3137  if (!type)
3138  return failure();
3139 
3140  // Convert the original function types.
3141  TypeConverter::SignatureConversion result(type.getNumInputs());
3142  SmallVector<Type, 1> newResults;
3143  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3144  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3145  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3146  typeConverter, &result)))
3147  return failure();
3148 
3149  // Update the function signature in-place.
3150  auto newType = FunctionType::get(rewriter.getContext(),
3151  result.getConvertedTypes(), newResults);
3152 
3153  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3154 
3155  return success();
3156 }
3157 
3158 /// Create a default conversion pattern that rewrites the type signature of a
3159 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3160 /// represent their type.
3161 namespace {
3162 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3163  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3164  MLIRContext *ctx,
3165  const TypeConverter &converter)
3166  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3167 
3168  LogicalResult
3169  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3170  ConversionPatternRewriter &rewriter) const override {
3171  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3172  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3173  }
3174 };
3175 
3176 struct AnyFunctionOpInterfaceSignatureConversion
3177  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3179 
3180  LogicalResult
3181  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3182  ConversionPatternRewriter &rewriter) const override {
3183  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3184  }
3185 };
3186 } // namespace
3187 
3188 FailureOr<Operation *>
3190  const TypeConverter &converter,
3191  ConversionPatternRewriter &rewriter) {
3192  assert(op && "Invalid op");
3193  Location loc = op->getLoc();
3194  if (converter.isLegal(op))
3195  return rewriter.notifyMatchFailure(loc, "op already legal");
3196 
3197  OperationState newOp(loc, op->getName());
3198  newOp.addOperands(operands);
3199 
3200  SmallVector<Type> newResultTypes;
3201  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3202  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3203 
3204  newOp.addTypes(newResultTypes);
3205  newOp.addAttributes(op->getAttrs());
3206  return rewriter.create(newOp);
3207 }
3208 
3210  StringRef functionLikeOpName, RewritePatternSet &patterns,
3211  const TypeConverter &converter) {
3212  patterns.add<FunctionOpInterfaceSignatureConversion>(
3213  functionLikeOpName, patterns.getContext(), converter);
3214 }
3215 
3217  RewritePatternSet &patterns, const TypeConverter &converter) {
3218  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3219  converter, patterns.getContext());
3220 }
3221 
3222 //===----------------------------------------------------------------------===//
3223 // ConversionTarget
3224 //===----------------------------------------------------------------------===//
3225 
3227  LegalizationAction action) {
3228  legalOperations[op].action = action;
3229 }
3230 
3232  LegalizationAction action) {
3233  for (StringRef dialect : dialectNames)
3234  legalDialects[dialect] = action;
3235 }
3236 
3238  -> std::optional<LegalizationAction> {
3239  std::optional<LegalizationInfo> info = getOpInfo(op);
3240  return info ? info->action : std::optional<LegalizationAction>();
3241 }
3242 
3244  -> std::optional<LegalOpDetails> {
3245  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3246  if (!info)
3247  return std::nullopt;
3248 
3249  // Returns true if this operation instance is known to be legal.
3250  auto isOpLegal = [&] {
3251  // Handle dynamic legality with the provided legality function.
3252  if (info->action == LegalizationAction::Dynamic) {
3253  std::optional<bool> result = info->legalityFn(op);
3254  if (result)
3255  return *result;
3256  }
3257 
3258  // Otherwise, the operation is only legal if it was marked 'Legal'.
3259  return info->action == LegalizationAction::Legal;
3260  };
3261  if (!isOpLegal())
3262  return std::nullopt;
3263 
3264  // This operation is legal, compute any additional legality information.
3265  LegalOpDetails legalityDetails;
3266  if (info->isRecursivelyLegal) {
3267  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3268  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3269  legalityDetails.isRecursivelyLegal =
3270  legalityFnIt->second(op).value_or(true);
3271  } else {
3272  legalityDetails.isRecursivelyLegal = true;
3273  }
3274  }
3275  return legalityDetails;
3276 }
3277 
3279  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3280  if (!info)
3281  return false;
3282 
3283  if (info->action == LegalizationAction::Dynamic) {
3284  std::optional<bool> result = info->legalityFn(op);
3285  if (!result)
3286  return false;
3287 
3288  return !(*result);
3289  }
3290 
3291  return info->action == LegalizationAction::Illegal;
3292 }
3293 
3297  if (!oldCallback)
3298  return newCallback;
3299 
3300  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3301  Operation *op) -> std::optional<bool> {
3302  if (std::optional<bool> result = newCl(op))
3303  return *result;
3304 
3305  return oldCl(op);
3306  };
3307  return chain;
3308 }
3309 
3310 void ConversionTarget::setLegalityCallback(
3311  OperationName name, const DynamicLegalityCallbackFn &callback) {
3312  assert(callback && "expected valid legality callback");
3313  auto *infoIt = legalOperations.find(name);
3314  assert(infoIt != legalOperations.end() &&
3315  infoIt->second.action == LegalizationAction::Dynamic &&
3316  "expected operation to already be marked as dynamically legal");
3317  infoIt->second.legalityFn =
3318  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3319 }
3320 
3322  OperationName name, const DynamicLegalityCallbackFn &callback) {
3323  auto *infoIt = legalOperations.find(name);
3324  assert(infoIt != legalOperations.end() &&
3325  infoIt->second.action != LegalizationAction::Illegal &&
3326  "expected operation to already be marked as legal");
3327  infoIt->second.isRecursivelyLegal = true;
3328  if (callback)
3329  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3330  std::move(opRecursiveLegalityFns[name]), callback);
3331  else
3332  opRecursiveLegalityFns.erase(name);
3333 }
3334 
3335 void ConversionTarget::setLegalityCallback(
3336  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3337  assert(callback && "expected valid legality callback");
3338  for (StringRef dialect : dialects)
3339  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3340  std::move(dialectLegalityFns[dialect]), callback);
3341 }
3342 
3343 void ConversionTarget::setLegalityCallback(
3344  const DynamicLegalityCallbackFn &callback) {
3345  assert(callback && "expected valid legality callback");
3346  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3347 }
3348 
3349 auto ConversionTarget::getOpInfo(OperationName op) const
3350  -> std::optional<LegalizationInfo> {
3351  // Check for info for this specific operation.
3352  const auto *it = legalOperations.find(op);
3353  if (it != legalOperations.end())
3354  return it->second;
3355  // Check for info for the parent dialect.
3356  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3357  if (dialectIt != legalDialects.end()) {
3358  DynamicLegalityCallbackFn callback;
3359  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3360  if (dialectFn != dialectLegalityFns.end())
3361  callback = dialectFn->second;
3362  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3363  callback};
3364  }
3365  // Otherwise, check if we mark unknown operations as dynamic.
3366  if (unknownLegalityFn)
3367  return LegalizationInfo{LegalizationAction::Dynamic,
3368  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3369  return std::nullopt;
3370 }
3371 
3372 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3373 //===----------------------------------------------------------------------===//
3374 // PDL Configuration
3375 //===----------------------------------------------------------------------===//
3376 
3378  auto &rewriterImpl =
3379  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3380  rewriterImpl.currentTypeConverter = getTypeConverter();
3381 }
3382 
3384  auto &rewriterImpl =
3385  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3386  rewriterImpl.currentTypeConverter = nullptr;
3387 }
3388 
3389 /// Remap the given value using the rewriter and the type converter in the
3390 /// provided config.
3391 static FailureOr<SmallVector<Value>>
3393  SmallVector<Value> mappedValues;
3394  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3395  return failure();
3396  return std::move(mappedValues);
3397 }
3398 
3400  patterns.getPDLPatterns().registerRewriteFunction(
3401  "convertValue",
3402  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3403  auto results = pdllConvertValues(
3404  static_cast<ConversionPatternRewriter &>(rewriter), value);
3405  if (failed(results))
3406  return failure();
3407  return results->front();
3408  });
3409  patterns.getPDLPatterns().registerRewriteFunction(
3410  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3411  return pdllConvertValues(
3412  static_cast<ConversionPatternRewriter &>(rewriter), values);
3413  });
3414  patterns.getPDLPatterns().registerRewriteFunction(
3415  "convertType",
3416  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3417  auto &rewriterImpl =
3418  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3419  if (const TypeConverter *converter =
3420  rewriterImpl.currentTypeConverter) {
3421  if (Type newType = converter->convertType(type))
3422  return newType;
3423  return failure();
3424  }
3425  return type;
3426  });
3427  patterns.getPDLPatterns().registerRewriteFunction(
3428  "convertTypes",
3429  [](PatternRewriter &rewriter,
3430  TypeRange types) -> FailureOr<SmallVector<Type>> {
3431  auto &rewriterImpl =
3432  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3433  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3434  if (!converter)
3435  return SmallVector<Type>(types);
3436 
3437  SmallVector<Type> remappedTypes;
3438  if (failed(converter->convertTypes(types, remappedTypes)))
3439  return failure();
3440  return std::move(remappedTypes);
3441  });
3442 }
3443 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3444 
3445 //===----------------------------------------------------------------------===//
3446 // Op Conversion Entry Points
3447 //===----------------------------------------------------------------------===//
3448 
3449 //===----------------------------------------------------------------------===//
3450 // Partial Conversion
3451 
3453  ArrayRef<Operation *> ops, const ConversionTarget &target,
3454  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3455  OperationConverter opConverter(target, patterns, config,
3456  OpConversionMode::Partial);
3457  return opConverter.convertOperations(ops);
3458 }
3459 LogicalResult
3461  const FrozenRewritePatternSet &patterns,
3462  ConversionConfig config) {
3463  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3464 }
3465 
3466 //===----------------------------------------------------------------------===//
3467 // Full Conversion
3468 
3470  const ConversionTarget &target,
3471  const FrozenRewritePatternSet &patterns,
3472  ConversionConfig config) {
3473  OperationConverter opConverter(target, patterns, config,
3474  OpConversionMode::Full);
3475  return opConverter.convertOperations(ops);
3476 }
3478  const ConversionTarget &target,
3479  const FrozenRewritePatternSet &patterns,
3480  ConversionConfig config) {
3481  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3482 }
3483 
3484 //===----------------------------------------------------------------------===//
3485 // Analysis Conversion
3486 
3489  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3490  OperationConverter opConverter(target, patterns, config,
3491  OpConversionMode::Analysis);
3492  return opConverter.convertOperations(ops);
3493 }
3494 LogicalResult
3496  const FrozenRewritePatternSet &patterns,
3497  ConversionConfig config) {
3498  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3499 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterializationRewrite &mat, DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
static RewriteTy * findSingleRewrite(R &&rewrites, Block *block)
Find the single rewrite object of the specified type and block among the given rewrites.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterializationRewrite * > &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process,...
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
bool empty()
Definition: Block.h:146
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:93
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator end()
Definition: Block.h:142
iterator begin()
Definition: Block.h:141
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:313
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:330
Block::iterator getPoint() const
Definition: Builders.h:343
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:340
Block * getBlock() const
Definition: Builders.h:342
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:323
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:441
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:483
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
OpResult getOpResult(unsigned idx)
Definition: Operation.h:416
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:605
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:892
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:281
MLIRContext * getContext() const
Definition: PatternMatch.h:823
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_iterator user_end() const
Definition: Value.h:227
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition: Argument.h:64
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a single value that remaps an existing signature input...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
DenseSet< void * > erased
Pointers to all erased operations and blocks.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
Value buildUnresolvedMaterialization(MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, const TypeConverter *converter)
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.